Merge branch 'main' into allow-dynamic-models-ollama

Matthew Farrellee 2025-07-21 05:17:29 -04:00
commit c67bae2d07
145 changed files with 6481 additions and 5159 deletions

@@ -4,3 +4,9 @@ omit =
*/llama_stack/providers/*
*/llama_stack/templates/*
.venv/*
*/llama_stack/cli/scripts/*
*/llama_stack/ui/*
*/llama_stack/distribution/ui/*
*/llama_stack/strong_typing/*
*/llama_stack/env.py
*/__init__.py

.github/CODEOWNERS

@@ -2,4 +2,4 @@
# These owners will be the default owners for everything in
# the repo. Unless a later match takes precedence,
-* @ashwinb @yanxi0830 @hardikjshah @raghotham @ehhuang @terrytangyuan @leseb @bbrowning @reluctantfuturist @mattf
* @ashwinb @yanxi0830 @hardikjshah @raghotham @ehhuang @terrytangyuan @leseb @bbrowning @reluctantfuturist @mattf @slekkala1

.github/ISSUE_TEMPLATE/tech-debt.yml

@@ -0,0 +1,30 @@
name: 🔧 Tech Debt
description: Something that is functional but should be improved or optimized
labels: ["tech-debt"]
body:
- type: textarea
id: tech-debt-explanation
attributes:
label: 🤔 What is the technical debt you think should be addressed?
description: >
A clear and concise description of _what_ needs to be addressed - ensure that what you are describing
constitutes [technical debt](https://en.wikipedia.org/wiki/Technical_debt) and is not a bug
or feature request.
validations:
required: true
- type: textarea
id: tech-debt-motivation
attributes:
label: 💡 What is the benefit of addressing this technical debt?
description: >
A clear and concise description of _why_ this work is needed.
validations:
required: true
- type: textarea
id: other-thoughts
attributes:
label: Other thoughts
description: >
Any thoughts about how this may result in complexity in the codebase, or other trade-offs.

@@ -7,7 +7,5 @@ runs:
shell: bash
run: |
docker run -d --name ollama -p 11434:11434 docker.io/leseb/ollama-with-models
-# TODO: rebuild an ollama image with llama-guard3:1b
echo "Verifying Ollama status..."
timeout 30 bash -c 'while ! curl -s -L http://127.0.0.1:11434; do sleep 1 && echo "."; done'
-docker exec ollama ollama pull llama-guard3:1b

@@ -5,6 +5,10 @@ inputs:
description: The Python version to use
required: false
default: "3.12"
client-version:
description: The llama-stack-client-python version to test against (latest or published)
required: false
default: "latest"
runs:
using: "composite"
steps:
@@ -20,8 +24,17 @@ runs:
run: |
uv sync --all-groups
uv pip install ollama faiss-cpu
-# always test against the latest version of the client
-# TODO: this is not necessarily a good idea. we need to test against both published and latest
-# to find out backwards compatibility issues.
-uv pip install git+https://github.com/meta-llama/llama-stack-client-python.git@main
# Install llama-stack-client-python based on the client-version input
if [ "${{ inputs.client-version }}" = "latest" ]; then
echo "Installing latest llama-stack-client-python from main branch"
uv pip install git+https://github.com/meta-llama/llama-stack-client-python.git@main
elif [ "${{ inputs.client-version }}" = "published" ]; then
echo "Installing published llama-stack-client-python from PyPI"
uv pip install llama-stack-client
else
echo "Invalid client-version: ${{ inputs.client-version }}"
exit 1
fi
uv pip install -e .

.github/workflows/coverage-badge.yml

@@ -0,0 +1,57 @@
name: Coverage Badge
on:
push:
branches: [ main ]
paths:
- 'llama_stack/**'
- 'tests/unit/**'
- 'uv.lock'
- 'pyproject.toml'
- 'requirements.txt'
- '.github/workflows/unit-tests.yml'
- '.github/workflows/coverage-badge.yml' # This workflow
workflow_dispatch:
jobs:
unit-tests:
runs-on: ubuntu-latest
steps:
- name: Checkout repository
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
- name: Install dependencies
uses: ./.github/actions/setup-runner
- name: Run unit tests
run: |
./scripts/unit-tests.sh
- name: Coverage Badge
uses: tj-actions/coverage-badge-py@1788babcb24544eb5bbb6e0d374df5d1e54e670f # v2.0.4
- name: Verify Changed files
uses: tj-actions/verify-changed-files@a1c6acee9df209257a246f2cc6ae8cb6581c1edf # v20.0.4
id: verify-changed-files
with:
files: coverage.svg
- name: Commit files
if: steps.verify-changed-files.outputs.files_changed == 'true'
run: |
git config --local user.email "github-actions[bot]@users.noreply.github.com"
git config --local user.name "github-actions[bot]"
git add coverage.svg
git commit -m "Updated coverage.svg"
- name: Create Pull Request
if: steps.verify-changed-files.outputs.files_changed == 'true'
uses: peter-evans/create-pull-request@271a8d0340265f705b14b6d32b9829c1cb33d45e # v7.0.8
with:
token: ${{ secrets.GITHUB_TOKEN }}
title: "ci: [Automatic] Coverage Badge Update"
body: |
This PR updates the coverage badge based on the latest coverage report.
Automatically generated by the [workflow coverage-badge.yaml](.github/workflows/coverage-badge.yaml)
delete-branch: true

@@ -1,355 +0,0 @@
name: "Run Llama-stack Tests"
on:
#### Temporarily disable PR runs until tests run as intended within mainline.
#TODO Add this back.
#pull_request_target:
# types: ["opened"]
# branches:
# - 'main'
# paths:
# - 'llama_stack/**/*.py'
# - 'tests/**/*.py'
workflow_dispatch:
inputs:
runner:
description: 'GHA Runner Scale Set label to run workflow on.'
required: true
default: "llama-stack-gha-runner-gpu"
checkout_reference:
description: "The branch, tag, or SHA to checkout"
required: true
default: "main"
debug:
description: 'Run debugging steps?'
required: false
default: "true"
sleep_time:
description: '[DEBUG] sleep time for debugging'
required: true
default: "0"
provider_id:
description: 'ID of your provider'
required: true
default: "meta_reference"
model_id:
description: 'Shorthand name for target model ID (llama_3b or llama_8b)'
required: true
default: "llama_3b"
model_override_3b:
description: 'Specify shorthand model for <llama_3b> '
required: false
default: "Llama3.2-3B-Instruct"
model_override_8b:
description: 'Specify shorthand model for <llama_8b> '
required: false
default: "Llama3.1-8B-Instruct"
env:
# ID used for each test's provider config
PROVIDER_ID: "${{ inputs.provider_id || 'meta_reference' }}"
# Path to model checkpoints within EFS volume
MODEL_CHECKPOINT_DIR: "/data/llama"
# Path to directory to run tests from
TESTS_PATH: "${{ github.workspace }}/llama_stack/providers/tests"
# Keep track of a list of model IDs that are valid to use within pytest fixture marks
AVAILABLE_MODEL_IDs: "llama_3b llama_8b"
# Shorthand name for model ID, used in pytest fixture marks
MODEL_ID: "${{ inputs.model_id || 'llama_3b' }}"
# Override the `llama_3b` / `llama_8b' models, else use the default.
LLAMA_3B_OVERRIDE: "${{ inputs.model_override_3b || 'Llama3.2-3B-Instruct' }}"
LLAMA_8B_OVERRIDE: "${{ inputs.model_override_8b || 'Llama3.1-8B-Instruct' }}"
# Defines which directories in TESTS_PATH to exclude from the test loop
EXCLUDED_DIRS: "__pycache__"
# Defines the output xml reports generated after a test is run
REPORTS_GEN: ""
jobs:
execute_workflow:
name: Execute workload on Self-Hosted GPU k8s runner
permissions:
pull-requests: write
defaults:
run:
shell: bash
runs-on: ${{ inputs.runner != '' && inputs.runner || 'llama-stack-gha-runner-gpu' }}
if: always()
steps:
##############################
#### INITIAL DEBUG CHECKS ####
##############################
- name: "[DEBUG] Check content of the EFS mount"
id: debug_efs_volume
continue-on-error: true
if: inputs.debug == 'true'
run: |
echo "========= Content of the EFS mount ============="
ls -la ${{ env.MODEL_CHECKPOINT_DIR }}
- name: "[DEBUG] Get runner container OS information"
id: debug_os_info
if: ${{ inputs.debug == 'true' }}
run: |
cat /etc/os-release
- name: "[DEBUG] Print environment variables"
id: debug_env_vars
if: ${{ inputs.debug == 'true' }}
run: |
echo "PROVIDER_ID = ${PROVIDER_ID}"
echo "MODEL_CHECKPOINT_DIR = ${MODEL_CHECKPOINT_DIR}"
echo "AVAILABLE_MODEL_IDs = ${AVAILABLE_MODEL_IDs}"
echo "MODEL_ID = ${MODEL_ID}"
echo "LLAMA_3B_OVERRIDE = ${LLAMA_3B_OVERRIDE}"
echo "LLAMA_8B_OVERRIDE = ${LLAMA_8B_OVERRIDE}"
echo "EXCLUDED_DIRS = ${EXCLUDED_DIRS}"
echo "REPORTS_GEN = ${REPORTS_GEN}"
############################
#### MODEL INPUT CHECKS ####
############################
- name: "Check if env.model_id is valid"
id: check_model_id
run: |
if [[ " ${AVAILABLE_MODEL_IDs[@]} " =~ " ${MODEL_ID} " ]]; then
echo "Model ID '${MODEL_ID}' is valid."
else
echo "Model ID '${MODEL_ID}' is invalid. Terminating workflow."
exit 1
fi
#######################
#### CODE CHECKOUT ####
#######################
- name: "Checkout 'meta-llama/llama-stack' repository"
id: checkout_repo
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
with:
ref: ${{ inputs.branch }}
- name: "[DEBUG] Content of the repository after checkout"
id: debug_content_after_checkout
if: ${{ inputs.debug == 'true' }}
run: |
ls -la ${GITHUB_WORKSPACE}
##########################################################
#### OPTIONAL SLEEP DEBUG ####
# #
# Use to "exec" into the test k8s POD and run tests #
# manually to identify what dependencies are being used. #
# #
##########################################################
- name: "[DEBUG] sleep"
id: debug_sleep
if: ${{ inputs.debug == 'true' && inputs.sleep_time != '' }}
run: |
sleep ${{ inputs.sleep_time }}
############################
#### UPDATE SYSTEM PATH ####
############################
- name: "Update path: execute"
id: path_update_exec
run: |
# .local/bin is needed for certain libraries installed below to be recognized
# when calling their executable to install sub-dependencies
mkdir -p ${HOME}/.local/bin
echo "${HOME}/.local/bin" >> "$GITHUB_PATH"
#####################################
#### UPDATE CHECKPOINT DIRECTORY ####
#####################################
- name: "Update checkpoint directory"
id: checkpoint_update
run: |
echo "Checkpoint directory: ${MODEL_CHECKPOINT_DIR}/$LLAMA_3B_OVERRIDE"
if [ "${MODEL_ID}" = "llama_3b" ] && [ -d "${MODEL_CHECKPOINT_DIR}/${LLAMA_3B_OVERRIDE}" ]; then
echo "MODEL_CHECKPOINT_DIR=${MODEL_CHECKPOINT_DIR}/${LLAMA_3B_OVERRIDE}" >> "$GITHUB_ENV"
elif [ "${MODEL_ID}" = "llama_8b" ] && [ -d "${MODEL_CHECKPOINT_DIR}/${LLAMA_8B_OVERRIDE}" ]; then
echo "MODEL_CHECKPOINT_DIR=${MODEL_CHECKPOINT_DIR}/${LLAMA_8B_OVERRIDE}" >> "$GITHUB_ENV"
else
echo "MODEL_ID & LLAMA_*B_OVERRIDE are not a valid pairing. Terminating workflow."
exit 1
fi
- name: "[DEBUG] Checkpoint update check"
id: debug_checkpoint_update
if: ${{ inputs.debug == 'true' }}
run: |
echo "MODEL_CHECKPOINT_DIR (after update) = ${MODEL_CHECKPOINT_DIR}"
##################################
#### DEPENDENCY INSTALLATIONS ####
##################################
- name: "Installing 'apt' required packages"
id: install_apt
run: |
echo "[STEP] Installing 'apt' required packages"
sudo apt update -y
sudo apt install -y python3 python3-pip npm wget
- name: "Installing packages with 'curl'"
id: install_curl
run: |
curl -fsSL https://ollama.com/install.sh | sh
- name: "Installing packages with 'wget'"
id: install_wget
run: |
wget https://repo.continuum.io/miniconda/Miniconda3-latest-Linux-x86_64.sh
chmod +x Miniconda3-latest-Linux-x86_64.sh
./Miniconda3-latest-Linux-x86_64.sh -b install -c pytorch -c nvidia faiss-gpu=1.9.0
# Add miniconda3 bin to system path
echo "${HOME}/miniconda3/bin" >> "$GITHUB_PATH"
- name: "Installing packages with 'npm'"
id: install_npm_generic
run: |
sudo npm install -g junit-merge
- name: "Installing pip dependencies"
id: install_pip_generic
run: |
echo "[STEP] Installing 'llama-stack' models"
pip install -U pip setuptools
pip install -r requirements.txt
pip install -e .
pip install -U \
torch torchvision \
pytest pytest_asyncio \
fairscale lm-format-enforcer \
zmq chardet pypdf \
pandas sentence_transformers together \
aiosqlite
- name: "Installing packages with conda"
id: install_conda_generic
run: |
conda install -q -c pytorch -c nvidia faiss-gpu=1.9.0
#############################################################
#### TESTING TO BE DONE FOR BOTH PRS AND MANUAL DISPATCH ####
#############################################################
- name: "Run Tests: Loop"
id: run_tests_loop
working-directory: "${{ github.workspace }}"
run: |
pattern=""
for dir in llama_stack/providers/tests/*; do
if [ -d "$dir" ]; then
dir_name=$(basename "$dir")
if [[ ! " $EXCLUDED_DIRS " =~ " $dir_name " ]]; then
for file in "$dir"/test_*.py; do
test_name=$(basename "$file")
new_file="result-${dir_name}-${test_name}.xml"
if torchrun $(which pytest) -s -v ${TESTS_PATH}/${dir_name}/${test_name} -m "${PROVIDER_ID} and ${MODEL_ID}" \
--junitxml="${{ github.workspace }}/${new_file}"; then
echo "Ran test: ${test_name}"
else
echo "Did NOT run test: ${test_name}"
fi
pattern+="${new_file} "
done
fi
fi
done
echo "REPORTS_GEN=$pattern" >> "$GITHUB_ENV"
- name: "Test Summary: Merge"
id: test_summary_merge
working-directory: "${{ github.workspace }}"
run: |
echo "Merging the following test result files: ${REPORTS_GEN}"
# Defaults to merging them into 'merged-test-results.xml'
junit-merge ${{ env.REPORTS_GEN }}
############################################
#### AUTOMATIC TESTING ON PULL REQUESTS ####
############################################
#### Run tests ####
- name: "PR - Run Tests"
id: pr_run_tests
working-directory: "${{ github.workspace }}"
if: github.event_name == 'pull_request_target'
run: |
echo "[STEP] Running PyTest tests at 'GITHUB_WORKSPACE' path: ${GITHUB_WORKSPACE} | path: ${{ github.workspace }}"
# (Optional) Add more tests here.
# Merge test results with 'merged-test-results.xml' from above.
# junit-merge <new-test-results> merged-test-results.xml
#### Create test summary ####
- name: "PR - Test Summary"
id: pr_test_summary_create
if: github.event_name == 'pull_request_target'
uses: test-summary/action@31493c76ec9e7aa675f1585d3ed6f1da69269a86 # v2.4
with:
paths: "${{ github.workspace }}/merged-test-results.xml"
output: test-summary.md
- name: "PR - Upload Test Summary"
id: pr_test_summary_upload
if: github.event_name == 'pull_request_target'
uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # v4.6.2
with:
name: test-summary
path: test-summary.md
#### Update PR request ####
- name: "PR - Update comment"
id: pr_update_comment
if: github.event_name == 'pull_request_target'
uses: thollander/actions-comment-pull-request@24bffb9b452ba05a4f3f77933840a6a841d1b32b # v3.0.1
with:
filePath: test-summary.md
########################
#### MANUAL TESTING ####
########################
#### Run tests ####
- name: "Manual - Run Tests: Prep"
id: manual_run_tests
working-directory: "${{ github.workspace }}"
if: github.event_name == 'workflow_dispatch'
run: |
echo "[STEP] Running PyTest tests at 'GITHUB_WORKSPACE' path: ${{ github.workspace }}"
#TODO Use this when collection errors are resolved
# pytest -s -v -m "${PROVIDER_ID} and ${MODEL_ID}" --junitxml="${{ github.workspace }}/merged-test-results.xml"
# (Optional) Add more tests here.
# Merge test results with 'merged-test-results.xml' from above.
# junit-merge <new-test-results> merged-test-results.xml
#### Create test summary ####
- name: "Manual - Test Summary"
id: manual_test_summary
if: always() && github.event_name == 'workflow_dispatch'
uses: test-summary/action@31493c76ec9e7aa675f1585d3ed6f1da69269a86 # v2.4
with:
paths: "${{ github.workspace }}/merged-test-results.xml"

@@ -7,11 +7,20 @@ on:
branches: [ main ]
paths:
- 'llama_stack/**'
-- 'tests/integration/**'
- 'tests/**'
- 'uv.lock'
- 'pyproject.toml'
- 'requirements.txt'
- '.github/workflows/integration-tests.yml' # This workflow
- '.github/actions/setup-ollama/action.yml'
schedule:
- cron: '0 0 * * *' # Daily at 12 AM UTC
workflow_dispatch:
inputs:
test-all-client-versions:
description: 'Test against both the latest and published versions'
type: boolean
default: false
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}
@@ -45,6 +54,7 @@
test-type: ${{ fromJson(needs.discover-tests.outputs.test-type) }}
client-type: [library, server]
python-version: ["3.12", "3.13"]
client-version: ${{ (github.event_name == 'schedule' || github.event.inputs.test-all-client-versions == 'true') && fromJSON('["published", "latest"]') || fromJSON('["latest"]') }}
steps:
- name: Checkout repository
@@ -54,13 +64,14 @@
uses: ./.github/actions/setup-runner
with:
python-version: ${{ matrix.python-version }}
client-version: ${{ matrix.client-version }}
- name: Setup ollama
uses: ./.github/actions/setup-ollama
- name: Build Llama Stack
run: |
-uv run llama stack build --template starter --image-type venv
uv run llama stack build --template ci-tests --image-type venv
- name: Check Storage and Memory Available Before Tests
if: ${{ always() }}
@@ -81,15 +92,15 @@
shell: bash
run: |
if [ "${{ matrix.client-type }}" == "library" ]; then
-stack_config="starter"
stack_config="ci-tests"
else
-stack_config="server:starter"
stack_config="server:ci-tests"
fi
uv run pytest -s -v tests/integration/${{ matrix.test-type }} --stack-config=${stack_config} \
-k "not(builtin_tool or safety_with_image or code_interpreter or test_rag)" \
--text-model="ollama/llama3.2:3b-instruct-fp16" \
--embedding-model=all-MiniLM-L6-v2 \
---safety-shield=ollama \
--safety-shield=$SAFETY_MODEL \
--color=yes \
--capture=tee-sys | tee pytest-${{ matrix.test-type }}.log
@@ -108,7 +119,7 @@
if: ${{ always() }}
uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # v4.6.2
with:
-name: logs-${{ github.run_id }}-${{ github.run_attempt }}-${{ matrix.client-type }}-${{ matrix.test-type }}-${{ matrix.python-version }}
name: logs-${{ github.run_id }}-${{ github.run_attempt }}-${{ matrix.client-type }}-${{ matrix.test-type }}-${{ matrix.python-version }}-${{ matrix.client-version }}
path: |
*.log
retention-days: 1

@@ -93,7 +93,7 @@ jobs:
- name: Build Llama Stack
run: |
-uv run llama stack build --template starter --image-type venv
uv run llama stack build --template ci-tests --image-type venv
- name: Check Storage and Memory Available Before Tests
if: ${{ always() }}

@@ -97,9 +97,9 @@ jobs:
- name: Build a single provider
run: |
-yq -i '.image_type = "container"' llama_stack/templates/starter/build.yaml
yq -i '.image_type = "container"' llama_stack/templates/ci-tests/build.yaml
-yq -i '.image_name = "test"' llama_stack/templates/starter/build.yaml
yq -i '.image_name = "test"' llama_stack/templates/ci-tests/build.yaml
-USE_COPY_NOT_MOUNT=true LLAMA_STACK_DIR=. uv run llama stack build --config llama_stack/templates/starter/build.yaml
USE_COPY_NOT_MOUNT=true LLAMA_STACK_DIR=. uv run llama stack build --config llama_stack/templates/ci-tests/build.yaml
- name: Inspect the container image entrypoint
run: |
@@ -126,14 +126,14 @@
.image_type = "container" |
.image_name = "ubi9-test" |
.distribution_spec.container_image = "registry.access.redhat.com/ubi9:latest"
-' llama_stack/templates/starter/build.yaml
' llama_stack/templates/ci-tests/build.yaml
- name: Build dev container (UBI9)
env:
USE_COPY_NOT_MOUNT: "true"
LLAMA_STACK_DIR: "."
run: |
-uv run llama stack build --config llama_stack/templates/starter/build.yaml
uv run llama stack build --config llama_stack/templates/ci-tests/build.yaml
- name: Inspect UBI9 image
run: |

@@ -20,7 +20,7 @@ jobs:
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
- name: Install uv
-uses: astral-sh/setup-uv@bd01e18f51369d5a26f1651c3cb451d3417e3bba # v6.3.1
uses: astral-sh/setup-uv@7edac99f961f18b581bbd960d59d049f04c0002f # v6.4.1
with:
python-version: ${{ matrix.python-version }}
activate-environment: true

@@ -1,69 +0,0 @@
name: auto-tests
on:
# pull_request:
workflow_dispatch:
inputs:
commit_sha:
description: 'Specific Commit SHA to trigger on'
required: false
default: $GITHUB_SHA # default to the last commit of $GITHUB_REF branch
jobs:
test-llama-stack-as-library:
runs-on: ubuntu-latest
env:
TOGETHER_API_KEY: ${{ secrets.TOGETHER_API_KEY }}
FIREWORKS_API_KEY: ${{ secrets.FIREWORKS_API_KEY }}
TAVILY_SEARCH_API_KEY: ${{ secrets.TAVILY_SEARCH_API_KEY }}
strategy:
matrix:
provider: [fireworks, together]
steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
with:
ref: ${{ github.event.inputs.commit_sha }}
- name: Echo commit SHA
run: |
echo "Triggered on commit SHA: ${{ github.event.inputs.commit_sha }}"
git rev-parse HEAD
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install -r requirements.txt pytest
pip install -e .
- name: Build providers
run: |
llama stack build --template ${{ matrix.provider }} --image-type venv
- name: Install the latest llama-stack-client & llama-models packages
run: |
pip install -e git+https://github.com/meta-llama/llama-stack-client-python.git#egg=llama-stack-client
pip install -e git+https://github.com/meta-llama/llama-models.git#egg=llama-models
- name: Run client-sdk test
working-directory: "${{ github.workspace }}"
env:
REPORT_OUTPUT: md_report.md
shell: bash
run: |
pip install --upgrade pytest-md-report
echo "REPORT_FILE=${REPORT_OUTPUT}" >> "$GITHUB_ENV"
export INFERENCE_MODEL=meta-llama/Llama-3.1-8B-Instruct
LLAMA_STACK_CONFIG=./llama_stack/templates/${{ matrix.provider }}/run.yaml pytest --md-report --md-report-verbose=1 ./tests/client-sdk/inference/ --md-report-output "$REPORT_OUTPUT"
- name: Output reports to the job summary
if: always()
shell: bash
run: |
if [ -f "$REPORT_FILE" ]; then
echo "<details><summary> Test Report for ${{ matrix.provider }} </summary>" >> $GITHUB_STEP_SUMMARY
echo "" >> $GITHUB_STEP_SUMMARY
cat "$REPORT_FILE" >> $GITHUB_STEP_SUMMARY
echo "" >> $GITHUB_STEP_SUMMARY
echo "</details>" >> $GITHUB_STEP_SUMMARY
fi

@@ -36,7 +36,7 @@ jobs:
- name: Run unit tests
run: |
-PYTHON_VERSION=${{ matrix.python }} ./scripts/unit-tests.sh --cov=llama_stack --junitxml=pytest-report-${{ matrix.python }}.xml --cov-report=html:htmlcov-${{ matrix.python }}
PYTHON_VERSION=${{ matrix.python }} ./scripts/unit-tests.sh --junitxml=pytest-report-${{ matrix.python }}.xml
- name: Upload test results
if: always()

@@ -129,6 +129,22 @@ repos:
require_serial: true
always_run: true
files: ^llama_stack/.*$
- id: forbid-pytest-asyncio
name: Block @pytest.mark.asyncio and @pytest_asyncio.fixture
entry: bash
language: system
types: [python]
pass_filenames: true
args:
- -c
- |
grep -EnH '^[^#]*@pytest\.mark\.asyncio|@pytest_asyncio\.fixture' "$@" && {
echo;
echo "❌ Do not use @pytest.mark.asyncio or @pytest_asyncio.fixture."
echo " pytest is already configured with async-mode=auto."
echo;
exit 1;
} || true
ci:
autofix_commit_msg: 🎨 [pre-commit.ci] Auto format from pre-commit.com hooks
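
For reference, the new hook above relies on pytest-asyncio's auto mode (as its error message states), so async tests and fixtures are written without decorators. A minimal sketch of the expected style; the test and fixture names here are illustrative, not taken from the repo:

```python
# Sketch of a test file that passes the forbid-pytest-asyncio hook, assuming
# pytest is configured with asyncio_mode = "auto" as the hook message says.
import pytest


@pytest.fixture
def base_url() -> str:
    # a plain fixture; an async fixture would likewise omit @pytest_asyncio.fixture
    return "http://localhost:8321"


async def test_base_url_scheme(base_url: str) -> None:
    # collected and run as an async test automatically; no @pytest.mark.asyncio needed
    assert base_url.startswith("http")
```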

@@ -112,7 +112,7 @@ uv run pre-commit run --all-files
## Running tests
-You can find the Llama Stack testing documentation here [here](tests/README.md).
You can find the Llama Stack testing documentation [here](https://github.com/meta-llama/llama-stack/blob/main/tests/README.md).
## Adding a new dependency to the project

@@ -6,6 +6,7 @@
[![Discord](https://img.shields.io/discord/1257833999603335178?color=6A7EC2&logo=discord&logoColor=ffffff)](https://discord.gg/llama-stack)
[![Unit Tests](https://github.com/meta-llama/llama-stack/actions/workflows/unit-tests.yml/badge.svg?branch=main)](https://github.com/meta-llama/llama-stack/actions/workflows/unit-tests.yml?query=branch%3Amain)
[![Integration Tests](https://github.com/meta-llama/llama-stack/actions/workflows/integration-tests.yml/badge.svg?branch=main)](https://github.com/meta-llama/llama-stack/actions/workflows/integration-tests.yml?query=branch%3Amain)
![coverage badge](./coverage.svg)
[**Quick Start**](https://llama-stack.readthedocs.io/en/latest/getting_started/index.html) | [**Documentation**](https://llama-stack.readthedocs.io/en/latest/index.html) | [**Colab Notebook**](./docs/getting_started.ipynb) | [**Discord**](https://discord.gg/llama-stack)

coverage.svg

@@ -0,0 +1,21 @@
<?xml version="1.0" encoding="UTF-8"?>
<svg xmlns="http://www.w3.org/2000/svg" width="99" height="20">
<linearGradient id="b" x2="0" y2="100%">
<stop offset="0" stop-color="#bbb" stop-opacity=".1"/>
<stop offset="1" stop-opacity=".1"/>
</linearGradient>
<mask id="a">
<rect width="99" height="20" rx="3" fill="#fff"/>
</mask>
<g mask="url(#a)">
<path fill="#555" d="M0 0h63v20H0z"/>
<path fill="#fe7d37" d="M63 0h36v20H63z"/>
<path fill="url(#b)" d="M0 0h99v20H0z"/>
</g>
<g fill="#fff" text-anchor="middle" font-family="DejaVu Sans,Verdana,Geneva,sans-serif" font-size="11">
<text x="31.5" y="15" fill="#010101" fill-opacity=".3">coverage</text>
<text x="31.5" y="14">coverage</text>
<text x="80" y="15" fill="#010101" fill-opacity=".3">44%</text>
<text x="80" y="14">44%</text>
</g>
</svg>


@@ -11340,6 +11340,9 @@
},
"embedding_dimension": {
"type": "integer"
},
"vector_db_name": {
"type": "string"
}
},
"additionalProperties": false,
@@ -13590,10 +13593,6 @@
"provider_id": {
"type": "string",
"description": "The ID of the provider to use for this vector store."
-},
-"provider_vector_db_id": {
-"type": "string",
-"description": "The provider-specific vector database ID."
}
},
"additionalProperties": false,
@@ -14471,28 +14470,31 @@
"DPOAlignmentConfig": {
"type": "object",
"properties": {
-"reward_scale": {
-"type": "number"
-},
-"reward_clip": {
-"type": "number"
-},
-"epsilon": {
-"type": "number"
-},
-"gamma": {
-"type": "number"
"beta": {
"type": "number"
},
"loss_type": {
"$ref": "#/components/schemas/DPOLossType",
"default": "sigmoid"
}
},
"additionalProperties": false,
"required": [
-"reward_scale",
-"reward_clip",
-"epsilon",
-"gamma"
"beta",
"loss_type"
],
"title": "DPOAlignmentConfig"
},
"DPOLossType": {
"type": "string",
"enum": [
"sigmoid",
"hinge",
"ipo",
"kto_pair"
],
"title": "DPOLossType"
},
"DataConfig": {
"type": "object",
"properties": {
@@ -15634,6 +15636,10 @@
"type": "string",
"description": "The identifier of the provider."
},
"vector_db_name": {
"type": "string",
"description": "The name of the vector database."
},
"provider_vector_db_id": {
"type": "string",
"description": "The identifier of the vector database in the provider."

@@ -7984,6 +7984,8 @@ components:
type: string
embedding_dimension:
type: integer
vector_db_name:
type: string
additionalProperties: false
required:
- identifier
@@ -9494,10 +9496,6 @@ components:
type: string
description: >-
The ID of the provider to use for this vector store.
-provider_vector_db_id:
-type: string
-description: >-
-The provider-specific vector database ID.
additionalProperties: false
required:
- name
@@ -10113,21 +10111,24 @@ components:
DPOAlignmentConfig:
type: object
properties:
-reward_scale:
-type: number
-reward_clip:
-type: number
-epsilon:
-type: number
-gamma:
-type: number
beta:
type: number
loss_type:
$ref: '#/components/schemas/DPOLossType'
default: sigmoid
additionalProperties: false
required:
-- reward_scale
-- reward_clip
-- epsilon
-- gamma
- beta
- loss_type
title: DPOAlignmentConfig
DPOLossType:
type: string
enum:
- sigmoid
- hinge
- ipo
- kto_pair
title: DPOLossType
DataConfig:
type: object
properties:
@@ -10945,6 +10946,9 @@ components:
provider_id:
type: string
description: The identifier of the provider.
vector_db_name:
type: string
description: The name of the vector database.
provider_vector_db_id:
type: string
description: >-
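
Both spec files above also add an optional `vector_db_name` to vector DB registration. Assuming the Python client passes it through as a keyword argument of the same name (this diff only shows the API schema, not the SDK signature), usage might look like the sketch below, which mirrors the register call used in the quickstart script later in this diff:

```python
# Hypothetical: vector_db_name as a kwarg is an assumption based on the new
# schema field; the other arguments match the quickstart example in this commit.
from llama_stack_client import LlamaStackClient

client = LlamaStackClient(base_url="http://localhost:8321")
client.vector_dbs.register(
    vector_db_id="my_demo_vector_db",
    vector_db_name="My demo vector DB",  # new optional field from this change
    embedding_model="all-MiniLM-L6-v2",
    embedding_dimension=384,
    provider_id="faiss",
)
```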

@@ -0,0 +1,6 @@
# Eval Providers
This section contains documentation for all available providers for the **eval** API.
- [inline::meta-reference](inline_meta-reference.md)
- [remote::nvidia](remote_nvidia.md)

@@ -0,0 +1,21 @@
# inline::meta-reference
## Description
Meta's reference implementation of evaluation tasks with support for multiple languages and evaluation metrics.
## Configuration
| Field | Type | Required | Default | Description |
|-------|------|----------|---------|-------------|
| `kvstore` | `utils.kvstore.config.RedisKVStoreConfig \| utils.kvstore.config.SqliteKVStoreConfig \| utils.kvstore.config.PostgresKVStoreConfig \| utils.kvstore.config.MongoDBKVStoreConfig` | No | sqlite | |
## Sample Configuration
```yaml
kvstore:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/dummy}/meta_reference_eval.db
```

@@ -0,0 +1,19 @@
# remote::nvidia
## Description
NVIDIA's evaluation provider for running evaluation tasks on NVIDIA's platform.
## Configuration
| Field | Type | Required | Default | Description |
|-------|------|----------|---------|-------------|
| `evaluator_url` | `<class 'str'>` | No | http://0.0.0.0:7331 | The url for accessing the evaluator service |
## Sample Configuration
```yaml
evaluator_url: ${env.NVIDIA_EVALUATOR_URL:=http://localhost:7331}
```

@@ -0,0 +1,33 @@
# Advanced APIs
## Post-training
Fine-tunes a model.
```{toctree}
:maxdepth: 1
post_training/index
```
## Eval
Generates outputs (via Inference or Agents) and perform scoring.
```{toctree}
:maxdepth: 1
eval/index
```
```{include} evaluation_concepts.md
:start-after: ## Evaluation Concepts
```
## Scoring
Evaluates the outputs of the system.
```{toctree}
:maxdepth: 1
scoring/index
```

@@ -0,0 +1,7 @@
# Post_Training Providers
This section contains documentation for all available providers for the **post_training** API.
- [inline::huggingface](inline_huggingface.md)
- [inline::torchtune](inline_torchtune.md)
- [remote::nvidia](remote_nvidia.md)

@@ -0,0 +1,33 @@
# inline::huggingface
## Description
HuggingFace-based post-training provider for fine-tuning models using the HuggingFace ecosystem.
## Configuration
| Field | Type | Required | Default | Description |
|-------|------|----------|---------|-------------|
| `device` | `<class 'str'>` | No | cuda | |
| `distributed_backend` | `Literal['fsdp', 'deepspeed']` | No | | |
| `checkpoint_format` | `Literal['full_state', 'huggingface']` | No | huggingface | |
| `chat_template` | `<class 'str'>` | No | | |
| `model_specific_config` | `<class 'dict'>` | No | {'trust_remote_code': True, 'attn_implementation': 'sdpa'} | |
| `max_seq_length` | `<class 'int'>` | No | 2048 | |
| `gradient_checkpointing` | `<class 'bool'>` | No | False | |
| `save_total_limit` | `<class 'int'>` | No | 3 | |
| `logging_steps` | `<class 'int'>` | No | 10 | |
| `warmup_ratio` | `<class 'float'>` | No | 0.1 | |
| `weight_decay` | `<class 'float'>` | No | 0.01 | |
| `dataloader_num_workers` | `<class 'int'>` | No | 4 | |
| `dataloader_pin_memory` | `<class 'bool'>` | No | True | |
## Sample Configuration
```yaml
checkpoint_format: huggingface
distributed_backend: null
device: cpu
```

@@ -0,0 +1,20 @@
# inline::torchtune
## Description
TorchTune-based post-training provider for fine-tuning and optimizing models using Meta's TorchTune framework.
## Configuration
| Field | Type | Required | Default | Description |
|-------|------|----------|---------|-------------|
| `torch_seed` | `int \| None` | No | | |
| `checkpoint_format` | `Literal['meta', 'huggingface']` | No | meta | |
## Sample Configuration
```yaml
checkpoint_format: meta
```

@@ -0,0 +1,28 @@
# remote::nvidia
## Description
NVIDIA's post-training provider for fine-tuning models on NVIDIA's platform.
## Configuration
| Field | Type | Required | Default | Description |
|-------|------|----------|---------|-------------|
| `api_key` | `str \| None` | No | | The NVIDIA API key. |
| `dataset_namespace` | `str \| None` | No | default | The NVIDIA dataset namespace. |
| `project_id` | `str \| None` | No | test-example-model@v1 | The NVIDIA project ID. |
| `customizer_url` | `str \| None` | No | | Base URL for the NeMo Customizer API |
| `timeout` | `<class 'int'>` | No | 300 | Timeout for the NVIDIA Post Training API |
| `max_retries` | `<class 'int'>` | No | 3 | Maximum number of retries for the NVIDIA Post Training API |
| `output_model_dir` | `<class 'str'>` | No | test-example-model@v1 | Directory to save the output model |
## Sample Configuration
```yaml
api_key: ${env.NVIDIA_API_KEY:=}
dataset_namespace: ${env.NVIDIA_DATASET_NAMESPACE:=default}
project_id: ${env.NVIDIA_PROJECT_ID:=test-project}
customizer_url: ${env.NVIDIA_CUSTOMIZER_URL:=http://nemo.test}
```

@@ -0,0 +1,7 @@
# Scoring Providers
This section contains documentation for all available providers for the **scoring** API.
- [inline::basic](inline_basic.md)
- [inline::braintrust](inline_braintrust.md)
- [inline::llm-as-judge](inline_llm-as-judge.md)

@@ -0,0 +1,13 @@
# inline::basic
## Description
Basic scoring provider for simple evaluation metrics and scoring functions.
## Sample Configuration
```yaml
{}
```

@@ -0,0 +1,19 @@
# inline::braintrust
## Description
Braintrust scoring provider for evaluation and scoring using the Braintrust platform.
## Configuration
| Field | Type | Required | Default | Description |
|-------|------|----------|---------|-------------|
| `openai_api_key` | `str \| None` | No | | The OpenAI API Key |
## Sample Configuration
```yaml
openai_api_key: ${env.OPENAI_API_KEY:=}
```

@@ -0,0 +1,13 @@
# inline::llm-as-judge
## Description
LLM-as-judge scoring provider that uses language models to evaluate and score responses.
## Sample Configuration
```yaml
{}
```

@@ -1,4 +1,4 @@
-# Building AI Applications (Examples)
# AI Application Examples
Llama Stack provides all the building blocks needed to create sophisticated AI applications.
@@ -27,4 +27,5 @@ tools
evals
telemetry
safety
playground/index
```

@@ -1,4 +1,4 @@
-# Llama Stack Playground
## Llama Stack Playground
```{note}
The Llama Stack Playground is currently experimental and subject to change. We welcome feedback and contributions to help improve it.
@@ -9,7 +9,7 @@ The Llama Stack Playground is an simple interface which aims to:
- Demo **end-to-end** application code to help users get started to build their own applications
- Provide an **UI** to help users inspect and understand Llama Stack API providers and resources
-## Key Features
### Key Features
#### Playground
Interactive pages for users to play with and explore Llama Stack API capabilities.
@@ -90,7 +90,7 @@ Interactive pages for users to play with and explore Llama Stack API capabilities.
- Under the hood, it uses Llama Stack's `/<resources>/list` API to get information about each resources.
- Please visit [Core Concepts](https://llama-stack.readthedocs.io/en/latest/concepts/index.html) for more details about the resources.
-## Starting the Llama Stack Playground
### Starting the Llama Stack Playground
To start the Llama Stack Playground, run the following commands:

@@ -1,31 +1,39 @@
-# Why Llama Stack?
## Llama Stack architecture
-Building production AI applications today requires solving multiple challenges:
Llama Stack allows you to build different layers of distributions for your AI workloads using various SDKs and API providers.
-**Infrastructure Complexity**
-- Running large language models efficiently requires specialized infrastructure.
-- Different deployment scenarios (local development, cloud, edge) need different solutions.
-- Moving from development to production often requires significant rework.
-**Essential Capabilities**
-- Safety guardrails and content filtering are necessary in an enterprise setting.
-- Just model inference is not enough - Knowledge retrieval and RAG capabilities are required.
-- Nearly any application needs composable multi-step workflows.
-- Finally, without monitoring, observability and evaluation, you end up operating in the dark.
-**Lack of Flexibility and Choice**
-- Directly integrating with multiple providers creates tight coupling.
-- Different providers have different APIs and abstractions.
-- Changing providers requires significant code changes.
-### Our Solution: A Universal Stack
```{image} ../../_static/llama-stack.png
:alt: Llama Stack
:width: 400px
```
### Benefits of Llama stack
#### Current challenges in custom AI applications
Building production AI applications today requires solving multiple challenges:
**Infrastructure Complexity**
- Running large language models efficiently requires specialized infrastructure.
- Different deployment scenarios (local development, cloud, edge) need different solutions.
- Moving from development to production often requires significant rework.
**Essential Capabilities**
- Safety guardrails and content filtering are necessary in an enterprise setting.
- Just model inference is not enough - Knowledge retrieval and RAG capabilities are required.
- Nearly any application needs composable multi-step workflows.
- Without monitoring, observability and evaluation, you end up operating in the dark.
**Lack of Flexibility and Choice**
- Directly integrating with multiple providers creates tight coupling.
- Different providers have different APIs and abstractions.
- Changing providers requires significant code changes.
#### Our Solution: A Universal Stack
Llama Stack addresses these challenges through a service-oriented, API-first approach:
**Develop Anywhere, Deploy Everywhere**

@@ -2,6 +2,10 @@
Given Llama Stack's service-oriented philosophy, a few concepts and workflows arise which may not feel completely natural in the LLM landscape, especially if you are coming with a background in other frameworks.
```{include} architecture.md
:start-after: ## Llama Stack architecture
```
```{include} apis.md
:start-after: ## APIs
```
@@ -10,14 +14,10 @@ Given Llama Stack's service-oriented philosophy, a few concepts and workflows ar
:start-after: ## API Providers
```
-```{include} resources.md
-:start-after: ## Resources
-```
```{include} distributions.md
:start-after: ## Distributions
```
-```{include} evaluation_concepts.md
```{include} resources.md
-:start-after: ## Evaluation Concepts
:start-after: ## Resources
```

@@ -52,7 +52,18 @@ extensions = [
"sphinxcontrib.redoc",
"sphinxcontrib.mermaid",
"sphinxcontrib.video",
"sphinx_reredirects"
]
redirects = {
"providers/post_training/index": "../../advanced_apis/post_training/index.html",
"providers/eval/index": "../../advanced_apis/eval/index.html",
"providers/scoring/index": "../../advanced_apis/scoring/index.html",
"playground/index": "../../building_applications/playground/index.html",
"openai/index": "../../providers/index.html#openai-api-compatibility",
"introduction/index": "../concepts/index.html#llama-stack-architecture"
}
myst_enable_extensions = ["colon_fence"]
html_theme = "sphinx_rtd_theme"

@@ -0,0 +1,4 @@
# Deployment Examples
```{include} kubernetes_deployment.md
```

@@ -1,4 +1,4 @@
-# Kubernetes Deployment Guide
## Kubernetes Deployment Guide
Instead of starting the Llama Stack and vLLM servers locally. We can deploy them in a Kubernetes cluster.
@@ -222,10 +222,21 @@ llama-stack-client --endpoint http://localhost:5000 inference chat-completion --
## Deploying Llama Stack Server in AWS EKS
-We've also provided a script to deploy the Llama Stack server in an AWS EKS cluster. Once you have an [EKS cluster](https://docs.aws.amazon.com/eks/latest/userguide/getting-started.html), you can run the following script to deploy the Llama Stack server.
We've also provided a script to deploy the Llama Stack server in an AWS EKS cluster.
Prerequisites:
- Set up an [EKS cluster](https://docs.aws.amazon.com/eks/latest/userguide/getting-started.html).
- Create a [Github OAuth app](https://docs.github.com/en/apps/oauth-apps/building-oauth-apps/creating-an-oauth-app) and get the client ID and client secret.
- Set the `Authorization callback URL` to `http://<your-llama-stack-ui-url>/api/auth/callback/`
Run the following script to deploy the Llama Stack server:
```
export HF_TOKEN=<your-huggingface-token>
export GITHUB_CLIENT_ID=<your-github-client-id>
export GITHUB_CLIENT_SECRET=<your-github-client-secret>
export LLAMA_STACK_UI_URL=<your-llama-stack-ui-url>
cd docs/source/distributions/eks
./apply.sh
```

@@ -6,14 +6,9 @@ This section provides an overview of the distributions available in Llama Stack.
```{toctree}
:maxdepth: 3
list_of_distributions
building_distro
customizing_run_yaml
importing_as_library
configuration
-customizing_run_yaml
-list_of_distributions
-kubernetes_deployment
-building_distro
-on_device_distro
-remote_hosted_distro
-self_hosted_distro
```

@@ -21,6 +21,24 @@ else
exit 1
fi
if [ -z "${GITHUB_CLIENT_ID:-}" ]; then
echo "ERROR: GITHUB_CLIENT_ID not set. You need it for Github login to work. Refer to https://llama-stack.readthedocs.io/en/latest/deploying/index.html#kubernetes-deployment-guide"
exit 1
fi
if [ -z "${GITHUB_CLIENT_SECRET:-}" ]; then
echo "ERROR: GITHUB_CLIENT_SECRET not set. You need it for Github login to work. Refer to https://llama-stack.readthedocs.io/en/latest/deploying/index.html#kubernetes-deployment-guide"
exit 1
fi
if [ -z "${LLAMA_STACK_UI_URL:-}" ]; then
echo "ERROR: LLAMA_STACK_UI_URL not set. Should be set to the external URL of the UI (excluding port). You need it for Github login to work. Refer to https://llama-stack.readthedocs.io/en/latest/deploying/index.html#kubernetes-deployment-guide"
exit 1
fi
set -euo pipefail
set -x

@@ -122,6 +122,9 @@ data:
provider_id: rag-runtime
server:
port: 8321
auth:
provider_config:
type: github_token
kind: ConfigMap
metadata:
creationTimestamp: null

@@ -27,7 +27,7 @@ spec:
spec:
containers:
- name: llama-stack
-image: llamastack/distribution-remote-vllm:latest
image: llamastack/distribution-starter:latest
imagePullPolicy: Always # since we have specified latest instead of a version
env:
- name: ENABLE_CHROMADB

@@ -119,3 +119,6 @@ tool_groups:
provider_id: rag-runtime
server:
port: 8321
auth:
provider_config:
type: github_token
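
The two run.yaml changes above enable the `github_token` auth provider on the server. Assuming the provider validates a GitHub access token sent as a standard bearer token (the exact mechanics are not shown in this diff), a client request might look like the sketch below; the endpoint path is shown for illustration:

```python
# Hypothetical client call against a server configured with the github_token
# auth provider. The Bearer-token assumption is not confirmed by this diff.
import os

import requests

token = os.environ["GITHUB_TOKEN"]  # a GitHub access token for the logged-in user
resp = requests.get(
    "http://localhost:8321/v1/models",
    headers={"Authorization": f"Bearer {token}"},
    timeout=30,
)
resp.raise_for_status()
print(resp.json())
```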

@@ -26,6 +26,12 @@ spec:
value: "http://llama-stack-service:8321"
- name: LLAMA_STACK_UI_PORT
value: "8322"
- name: GITHUB_CLIENT_ID
value: "${GITHUB_CLIENT_ID}"
- name: GITHUB_CLIENT_SECRET
value: "${GITHUB_CLIENT_SECRET}"
- name: NEXTAUTH_URL
value: "${LLAMA_STACK_UI_URL}:8322"
args:
- -c
- |

@@ -167,7 +167,7 @@ When using the `:` pattern (like `${env.OLLAMA_INFERENCE_MODEL:__disabled__}`),
## Running the Distribution
-You can run the starter distribution via Docker or Conda.
You can run the starter distribution via Docker, Conda, or venv.
### Via Docker
@@ -186,17 +186,12 @@ docker run \
--port $LLAMA_STACK_PORT
```
-### Via Conda
### Via Conda or venv
-Make sure you have done `uv pip install llama-stack` and have the Llama Stack CLI available.
Ensure you have configured the starter distribution using the environment variables explained above.
```bash
-llama stack build --template starter --image-type conda
uv run --with llama-stack llama stack build --template starter --image-type <conda|venv> --run
-llama stack run distributions/starter/run.yaml \
---port 8321 \
---env OPENAI_API_KEY=your_openai_key \
---env FIREWORKS_API_KEY=your_fireworks_key \
---env TOGETHER_API_KEY=your_together_key
```
## Example Usage

@@ -28,5 +28,4 @@ If you have built a container image and want to deploy it in a Kubernetes cluste
importing_as_library
configuration
-kubernetes_deployment
```

@@ -1,4 +1,4 @@
-# Detailed Tutorial
## Detailed Tutorial
In this guide, we'll walk through how you can use the Llama Stack (server and client SDK) to test a simple agent.
A Llama Stack agent is a simple integrated system that can perform tasks by combining a Llama model for reasoning with
@@ -10,7 +10,7 @@ Llama Stack is a stateful service with REST APIs to support seamless transition
In this guide, we'll walk through how to build a RAG agent locally using Llama Stack with [Ollama](https://ollama.com/)
as the inference [provider](../providers/index.md#inference) for a Llama Model.
-## Step 1: Installation and Setup
### Step 1: Installation and Setup
Install Ollama by following the instructions on the [Ollama website](https://ollama.com/download), then
download Llama 3.2 3B model, and then start the Ollama service.
@@ -45,7 +45,7 @@ Setup your virtual environment.
uv sync --python 3.12
source .venv/bin/activate
```
-## Step 2: Run Llama Stack
### Step 2: Run Llama Stack
Llama Stack is a server that exposes multiple APIs, you connect with it using the Llama Stack client SDK.
::::{tab-set}
@@ -132,7 +132,7 @@ Now you can use the Llama Stack client to run inference and build agents!
You can reuse the server setup or use the [Llama Stack Client](https://github.com/meta-llama/llama-stack-client-python/).
Note that the client package is already included in the `llama-stack` package.
-## Step 3: Run Client CLI
### Step 3: Run Client CLI
Open a new terminal and navigate to the same directory you started the server from. Then set up a new or activate your
existing server virtual environment.
@@ -232,7 +232,7 @@ OpenAIChatCompletion(
)
```
-## Step 4: Run the Demos
### Step 4: Run the Demos
Note that these demos show the [Python Client SDK](../references/python_sdk_reference/index.md).
Other SDKs are also available, please refer to the [Client SDK](../index.md#client-sdks) list for the complete options.
@@ -242,7 +242,7 @@ Other SDKs are also available, please refer to the [Client SDK](../index.md#clie
:::{tab-item} Basic Inference
Now you can run inference using the Llama Stack client SDK.
-### i. Create the Script
#### i. Create the Script
Create a file `inference.py` and add the following code:
```python
@@ -269,7 +269,7 @@ response = client.chat.completions.create(
print(response)
```
-### ii. Run the Script
#### ii. Run the Script
Let's run the script using `uv`
```bash
uv run python inference.py
@@ -283,7 +283,7 @@ OpenAIChatCompletion(id='chatcmpl-30cd0f28-a2ad-4b6d-934b-13707fc60ebf', choices
:::{tab-item} Build a Simple Agent
Next we can move beyond simple inference and build an agent that can perform tasks using the Llama Stack server.
-### i. Create the Script
#### i. Create the Script
Create a file `agent.py` and add the following code:
```python
@@ -455,7 +455,7 @@ uv run python agent.py
For our last demo, we can build a RAG agent that can answer questions about the Torchtune project using the documents
in a vector database.
-### i. Create the Script
#### i. Create the Script
Create a file `rag_agent.py` and add the following code:
```python
@@ -533,7 +533,7 @@ for t in turns:
for event in AgentEventLogger().log(stream):
event.print()
```
-### ii. Run the Script
#### ii. Run the Script
Let's run the script using `uv`
```bash
uv run python rag_agent.py

@@ -1,123 +1,13 @@
-# Quickstart
# Getting Started
-Get started with Llama Stack in minutes!
```{include} quickstart.md
:start-after: ## Quickstart
-Llama Stack is a stateful service with REST APIs to support the seamless transition of AI applications across different
-environments. You can build and test using a local server first and deploy to a hosted endpoint for production.
-In this guide, we'll walk through how to build a RAG application locally using Llama Stack with [Ollama](https://ollama.com/)
-as the inference [provider](../providers/inference/index) for a Llama Model.
-**💡 Notebook Version:** You can also follow this quickstart guide in a Jupyter notebook format: [quick_start.ipynb](https://github.com/meta-llama/llama-stack/blob/main/docs/quick_start.ipynb)
-#### Step 1: Install and setup
-1. Install [uv](https://docs.astral.sh/uv/)
-2. Run inference on a Llama model with [Ollama](https://ollama.com/download)
-```bash
-ollama run llama3.2:3b --keepalive 60m
```
-#### Step 2: Run the Llama Stack server
-We will use `uv` to run the Llama Stack server.
-```bash
-INFERENCE_MODEL=llama3.2:3b uv run --with llama-stack llama stack build --template starter --image-type venv --run
```{include} libraries.md
:start-after: ## Libraries (SDKs)
```
-#### Step 3: Run the demo
-Now open up a new terminal and copy the following script into a file named `demo_script.py`.
-```python
-from llama_stack_client import Agent, AgentEventLogger, RAGDocument, LlamaStackClient
```{include} detailed_tutorial.md
:start-after: ## Detailed Tutorial
vector_db_id = "my_demo_vector_db"
client = LlamaStackClient(base_url="http://localhost:8321")
models = client.models.list()
# Select the first LLM and first embedding models
model_id = next(m for m in models if m.model_type == "llm").identifier
embedding_model_id = (
em := next(m for m in models if m.model_type == "embedding")
).identifier
embedding_dimension = em.metadata["embedding_dimension"]
_ = client.vector_dbs.register(
vector_db_id=vector_db_id,
embedding_model=embedding_model_id,
embedding_dimension=embedding_dimension,
provider_id="faiss",
)
source = "https://www.paulgraham.com/greatwork.html"
print("rag_tool> Ingesting document:", source)
document = RAGDocument(
document_id="document_1",
content=source,
mime_type="text/html",
metadata={},
)
client.tool_runtime.rag_tool.insert(
documents=[document],
vector_db_id=vector_db_id,
chunk_size_in_tokens=50,
)
agent = Agent(
client,
model=model_id,
instructions="You are a helpful assistant",
tools=[
{
"name": "builtin::rag/knowledge_search",
"args": {"vector_db_ids": [vector_db_id]},
}
],
)
prompt = "How do you do great work?"
print("prompt>", prompt)
response = agent.create_turn(
messages=[{"role": "user", "content": prompt}],
session_id=agent.create_session("rag_session"),
stream=True,
)
for log in AgentEventLogger().log(response):
log.print()
``` ```
We will use `uv` to run the script
```
uv run --with llama-stack-client,fire,requests demo_script.py
```
And you should see output like below.
```
rag_tool> Ingesting document: https://www.paulgraham.com/greatwork.html
prompt> How do you do great work?
inference> [knowledge_search(query="What is the key to doing great work")]
tool_execution> Tool:knowledge_search Args:{'query': 'What is the key to doing great work'}
tool_execution> Tool:knowledge_search Response:[TextContentItem(text='knowledge_search tool found 5 chunks:\nBEGIN of knowledge_search tool results.\n', type='text'), TextContentItem(text="Result 1:\nDocument_id:docum\nContent: work. Doing great work means doing something important\nso well that you expand people's ideas of what's possible. But\nthere's no threshold for importance. It's a matter of degree, and\noften hard to judge at the time anyway.\n", type='text'), TextContentItem(text="Result 2:\nDocument_id:docum\nContent: work. Doing great work means doing something important\nso well that you expand people's ideas of what's possible. But\nthere's no threshold for importance. It's a matter of degree, and\noften hard to judge at the time anyway.\n", type='text'), TextContentItem(text="Result 3:\nDocument_id:docum\nContent: work. Doing great work means doing something important\nso well that you expand people's ideas of what's possible. But\nthere's no threshold for importance. It's a matter of degree, and\noften hard to judge at the time anyway.\n", type='text'), TextContentItem(text="Result 4:\nDocument_id:docum\nContent: work. Doing great work means doing something important\nso well that you expand people's ideas of what's possible. But\nthere's no threshold for importance. It's a matter of degree, and\noften hard to judge at the time anyway.\n", type='text'), TextContentItem(text="Result 5:\nDocument_id:docum\nContent: work. Doing great work means doing something important\nso well that you expand people's ideas of what's possible. But\nthere's no threshold for importance. It's a matter of degree, and\noften hard to judge at the time anyway.\n", type='text'), TextContentItem(text='END of knowledge_search tool results.\n', type='text')]
inference> Based on the search results, it seems that doing great work means doing something important so well that you expand people's ideas of what's possible. However, there is no clear threshold for importance, and it can be difficult to judge at the time.
To further clarify, I would suggest that doing great work involves:
* Completing tasks with high quality and attention to detail
* Expanding on existing knowledge or ideas
* Making a positive impact on others through your work
* Striving for excellence and continuous improvement
Ultimately, great work is about making a meaningful contribution and leaving a lasting impression.
```
Congratulations! You've successfully built your first RAG application using Llama Stack! 🎉🥳
## Next Steps
Now you're ready to dive deeper into Llama Stack!
- Explore the [Detailed Tutorial](./detailed_tutorial.md).
- Try the [Getting Started Notebook](https://github.com/meta-llama/llama-stack/blob/main/docs/getting_started.ipynb).
- Browse more [Notebooks on GitHub](https://github.com/meta-llama/llama-stack/tree/main/docs/notebooks).
- Learn about Llama Stack [Concepts](../concepts/index.md).
- Discover how to [Build Llama Stacks](../distributions/index.md).
- Refer to our [References](../references/index.md) for details on the Llama CLI and Python SDK.
- Check out the [llama-stack-apps](https://github.com/meta-llama/llama-stack-apps/tree/main/examples) repository for example applications and tutorials.
View file
@ -0,0 +1,10 @@
## Libraries (SDKs)
We have a number of client-side SDKs available for different languages.
| **Language** | **Client SDK** | **Package** |
| :----: | :----: | :----: |
| Python | [llama-stack-client-python](https://github.com/meta-llama/llama-stack-client-python) | [![PyPI version](https://img.shields.io/pypi/v/llama_stack_client.svg)](https://pypi.org/project/llama_stack_client/)
| Swift | [llama-stack-client-swift](https://github.com/meta-llama/llama-stack-client-swift/tree/latest-release) | [![Swift Package Index](https://img.shields.io/endpoint?url=https%3A%2F%2Fswiftpackageindex.com%2Fapi%2Fpackages%2Fmeta-llama%2Fllama-stack-client-swift%2Fbadge%3Ftype%3Dswift-versions)](https://swiftpackageindex.com/meta-llama/llama-stack-client-swift)
| Node | [llama-stack-client-node](https://github.com/meta-llama/llama-stack-client-node) | [![NPM version](https://img.shields.io/npm/v/llama-stack-client.svg)](https://npmjs.org/package/llama-stack-client)
| Kotlin | [llama-stack-client-kotlin](https://github.com/meta-llama/llama-stack-client-kotlin/tree/latest-release) | [![Maven version](https://img.shields.io/maven-central/v/com.llama.llamastack/llama-stack-client-kotlin)](https://central.sonatype.com/artifact/com.llama.llamastack/llama-stack-client-kotlin)
View file
@ -0,0 +1,129 @@
## Quickstart
Get started with Llama Stack in minutes!
Llama Stack is a stateful service with REST APIs to support the seamless transition of AI applications across different
environments. You can build and test using a local server first and deploy to a hosted endpoint for production.
In this guide, we'll walk through how to build a RAG application locally using Llama Stack with [Ollama](https://ollama.com/)
as the inference [provider](../providers/inference/index) for a Llama Model.
**💡 Notebook Version:** You can also follow this quickstart guide in a Jupyter notebook format: [quick_start.ipynb](https://github.com/meta-llama/llama-stack/blob/main/docs/quick_start.ipynb)
#### Step 1: Install and setup
1. Install [uv](https://docs.astral.sh/uv/)
2. Run inference on a Llama model with [Ollama](https://ollama.com/download)
```bash
ollama run llama3.2:3b --keepalive 60m
```
#### Step 2: Run the Llama Stack server
We will use `uv` to run the Llama Stack server.
```bash
ENABLE_OLLAMA=ollama OLLAMA_INFERENCE_MODEL=llama3.2:3b uv run --with llama-stack llama stack build --template starter --image-type venv --run
```
#### Step 3: Run the demo
Now open up a new terminal and copy the following script into a file named `demo_script.py`.
```python
from llama_stack_client import Agent, AgentEventLogger, RAGDocument, LlamaStackClient
vector_db_id = "my_demo_vector_db"
client = LlamaStackClient(base_url="http://localhost:8321")
models = client.models.list()
# Select the first LLM and first embedding models
model_id = next(m for m in models if m.model_type == "llm").identifier
embedding_model_id = (
em := next(m for m in models if m.model_type == "embedding")
).identifier
embedding_dimension = em.metadata["embedding_dimension"]
_ = client.vector_dbs.register(
vector_db_id=vector_db_id,
embedding_model=embedding_model_id,
embedding_dimension=embedding_dimension,
provider_id="faiss",
)
source = "https://www.paulgraham.com/greatwork.html"
print("rag_tool> Ingesting document:", source)
document = RAGDocument(
document_id="document_1",
content=source,
mime_type="text/html",
metadata={},
)
client.tool_runtime.rag_tool.insert(
documents=[document],
vector_db_id=vector_db_id,
chunk_size_in_tokens=50,
)
agent = Agent(
client,
model=model_id,
instructions="You are a helpful assistant",
tools=[
{
"name": "builtin::rag/knowledge_search",
"args": {"vector_db_ids": [vector_db_id]},
}
],
)
prompt = "How do you do great work?"
print("prompt>", prompt)
response = agent.create_turn(
messages=[{"role": "user", "content": prompt}],
session_id=agent.create_session("rag_session"),
stream=True,
)
for log in AgentEventLogger().log(response):
log.print()
```
We will use `uv` to run the script
```
uv run --with llama-stack-client,fire,requests demo_script.py
```
And you should see output like below.
```
rag_tool> Ingesting document: https://www.paulgraham.com/greatwork.html
prompt> How do you do great work?
inference> [knowledge_search(query="What is the key to doing great work")]
tool_execution> Tool:knowledge_search Args:{'query': 'What is the key to doing great work'}
tool_execution> Tool:knowledge_search Response:[TextContentItem(text='knowledge_search tool found 5 chunks:\nBEGIN of knowledge_search tool results.\n', type='text'), TextContentItem(text="Result 1:\nDocument_id:docum\nContent: work. Doing great work means doing something important\nso well that you expand people's ideas of what's possible. But\nthere's no threshold for importance. It's a matter of degree, and\noften hard to judge at the time anyway.\n", type='text'), TextContentItem(text="Result 2:\nDocument_id:docum\nContent: work. Doing great work means doing something important\nso well that you expand people's ideas of what's possible. But\nthere's no threshold for importance. It's a matter of degree, and\noften hard to judge at the time anyway.\n", type='text'), TextContentItem(text="Result 3:\nDocument_id:docum\nContent: work. Doing great work means doing something important\nso well that you expand people's ideas of what's possible. But\nthere's no threshold for importance. It's a matter of degree, and\noften hard to judge at the time anyway.\n", type='text'), TextContentItem(text="Result 4:\nDocument_id:docum\nContent: work. Doing great work means doing something important\nso well that you expand people's ideas of what's possible. But\nthere's no threshold for importance. It's a matter of degree, and\noften hard to judge at the time anyway.\n", type='text'), TextContentItem(text="Result 5:\nDocument_id:docum\nContent: work. Doing great work means doing something important\nso well that you expand people's ideas of what's possible. But\nthere's no threshold for importance. It's a matter of degree, and\noften hard to judge at the time anyway.\n", type='text'), TextContentItem(text='END of knowledge_search tool results.\n', type='text')]
inference> Based on the search results, it seems that doing great work means doing something important so well that you expand people's ideas of what's possible. However, there is no clear threshold for importance, and it can be difficult to judge at the time.
To further clarify, I would suggest that doing great work involves:
* Completing tasks with high quality and attention to detail
* Expanding on existing knowledge or ideas
* Making a positive impact on others through your work
* Striving for excellence and continuous improvement
Ultimately, great work is about making a meaningful contribution and leaving a lasting impression.
```
Congratulations! You've successfully built your first RAG application using Llama Stack! 🎉🥳
```{admonition} HuggingFace access
:class: tip
If you are getting a **401 Client Error** from HuggingFace for the **all-MiniLM-L6-v2** model, try setting **HF_TOKEN** to a valid HuggingFace token in your environment
```
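For example, on a Unix-like shell you might export a token before starting the server (the value below is a placeholder):
```bash
# Placeholder token value; substitute your own HuggingFace access token
export HF_TOKEN=hf_xxxxxxxxxxxxxxxx
```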
### Next Steps
Now you're ready to dive deeper into Llama Stack!
- Explore the [Detailed Tutorial](./detailed_tutorial.md).
- Try the [Getting Started Notebook](https://github.com/meta-llama/llama-stack/blob/main/docs/getting_started.ipynb).
- Browse more [Notebooks on GitHub](https://github.com/meta-llama/llama-stack/tree/main/docs/notebooks).
- Learn about Llama Stack [Concepts](../concepts/index.md).
- Discover how to [Build Llama Stacks](../distributions/index.md).
- Refer to our [References](../references/index.md) for details on the Llama CLI and Python SDK.
- Check out the [llama-stack-apps](https://github.com/meta-llama/llama-stack-apps/tree/main/examples) repository for example applications and tutorials.
View file
@ -40,17 +40,6 @@ Kotlin.
- Ready to build? Check out the [Quick Start](getting_started/index) to get started. - Ready to build? Check out the [Quick Start](getting_started/index) to get started.
- Want to contribute? See the [Contributing](contributing/index) guide. - Want to contribute? See the [Contributing](contributing/index) guide.
## Client SDKs
We have a number of client-side SDKs available for different languages.
| **Language** | **Client SDK** | **Package** |
| :----: | :----: | :----: |
| Python | [llama-stack-client-python](https://github.com/meta-llama/llama-stack-client-python) | [![PyPI version](https://img.shields.io/pypi/v/llama_stack_client.svg)](https://pypi.org/project/llama_stack_client/)
| Swift | [llama-stack-client-swift](https://github.com/meta-llama/llama-stack-client-swift/tree/latest-release) | [![Swift Package Index](https://img.shields.io/endpoint?url=https%3A%2F%2Fswiftpackageindex.com%2Fapi%2Fpackages%2Fmeta-llama%2Fllama-stack-client-swift%2Fbadge%3Ftype%3Dswift-versions)](https://swiftpackageindex.com/meta-llama/llama-stack-client-swift)
| Node | [llama-stack-client-node](https://github.com/meta-llama/llama-stack-client-node) | [![NPM version](https://img.shields.io/npm/v/llama-stack-client.svg)](https://npmjs.org/package/llama-stack-client)
| Kotlin | [llama-stack-client-kotlin](https://github.com/meta-llama/llama-stack-client-kotlin/tree/latest-release) | [![Maven version](https://img.shields.io/maven-central/v/com.llama.llamastack/llama-stack-client-kotlin)](https://central.sonatype.com/artifact/com.llama.llamastack/llama-stack-client-kotlin)
## Supported Llama Stack Implementations ## Supported Llama Stack Implementations
A number of "adapters" are available for some popular Inference and Vector Store providers. For other APIs (particularly Safety and Agents), we provide *reference implementations* you can use to get started. We expect this list to grow over time. We are slowly onboarding more providers to the ecosystem as we get more confidence in the APIs. A number of "adapters" are available for some popular Inference and Vector Store providers. For other APIs (particularly Safety and Agents), we provide *reference implementations* you can use to get started. We expect this list to grow over time. We are slowly onboarding more providers to the ecosystem as we get more confidence in the APIs.
@ -133,14 +122,12 @@ A number of "adapters" are available for some popular Inference and Vector Store
self self
getting_started/index getting_started/index
getting_started/detailed_tutorial
introduction/index
concepts/index concepts/index
openai/index
providers/index providers/index
distributions/index distributions/index
advanced_apis/index
building_applications/index building_applications/index
playground/index deploying/index
contributing/index contributing/index
references/index references/index
``` ```
View file
@ -1,4 +1,4 @@
# Providers Overview # API Providers Overview
The goal of Llama Stack is to build an ecosystem where users can easily swap out different implementations for the same API. Examples for these include: The goal of Llama Stack is to build an ecosystem where users can easily swap out different implementations for the same API. Examples for these include:
- LLM inference providers (e.g., Meta Reference, Ollama, Fireworks, Together, AWS Bedrock, Groq, Cerebras, SambaNova, vLLM, OpenAI, Anthropic, Gemini, WatsonX, etc.), - LLM inference providers (e.g., Meta Reference, Ollama, Fireworks, Together, AWS Bedrock, Groq, Cerebras, SambaNova, vLLM, OpenAI, Anthropic, Gemini, WatsonX, etc.),
@ -13,13 +13,25 @@ Providers come in two flavors:
Importantly, Llama Stack always strives to provide at least one fully inline provider for each API so you can iterate on a fully featured environment locally. Importantly, Llama Stack always strives to provide at least one fully inline provider for each API so you can iterate on a fully featured environment locally.
## External Providers ## External Providers
Llama Stack supports external providers that live outside of the main codebase. This allows you to create and maintain your own providers independently. Llama Stack supports external providers that live outside of the main codebase. This allows you to create and maintain your own providers independently.
```{toctree} ```{toctree}
:maxdepth: 1 :maxdepth: 1
external external.md
```
```{include} openai.md
:start-after: ## OpenAI API Compatibility
```
## Inference
Runs inference with an LLM.
```{toctree}
:maxdepth: 1
inference/index
``` ```
## Agents ## Agents
@ -40,33 +52,6 @@ Interfaces with datasets and data loaders.
datasetio/index datasetio/index
``` ```
## Eval
Generates outputs (via Inference or Agents) and perform scoring.
```{toctree}
:maxdepth: 1
eval/index
```
## Inference
Runs inference with an LLM.
```{toctree}
:maxdepth: 1
inference/index
```
## Post Training
Fine-tunes a model.
```{toctree}
:maxdepth: 1
post_training/index
```
## Safety ## Safety
Applies safety policies to the output at a Systems (not only model) level. Applies safety policies to the output at a Systems (not only model) level.
@ -76,15 +61,6 @@ Applies safety policies to the output at a Systems (not only model) level.
safety/index safety/index
``` ```
## Scoring
Evaluates the outputs of the system.
```{toctree}
:maxdepth: 1
scoring/index
```
## Telemetry ## Telemetry
Collects telemetry data from the system. Collects telemetry data from the system.
@ -94,15 +70,6 @@ Collects telemetry data from the system.
telemetry/index telemetry/index
``` ```
## Tool Runtime
Is associated with the ToolGroup resouces.
```{toctree}
:maxdepth: 1
tool_runtime/index
```
## Vector IO ## Vector IO
Vector IO refers to operations on vector databases, such as adding documents, searching, and deleting documents. Vector IO refers to operations on vector databases, such as adding documents, searching, and deleting documents.
@ -114,3 +81,12 @@ io and database are used to store and retrieve documents for retrieval.
vector_io/index vector_io/index
``` ```
## Tool Runtime
Is associated with the ToolGroup resources.
```{toctree}
:maxdepth: 1
tool_runtime/index
```
View file
@ -4,7 +4,6 @@ This section contains documentation for all available providers for the **infere
- [inline::meta-reference](inline_meta-reference.md) - [inline::meta-reference](inline_meta-reference.md)
- [inline::sentence-transformers](inline_sentence-transformers.md) - [inline::sentence-transformers](inline_sentence-transformers.md)
- [inline::vllm](inline_vllm.md)
- [remote::anthropic](remote_anthropic.md) - [remote::anthropic](remote_anthropic.md)
- [remote::bedrock](remote_bedrock.md) - [remote::bedrock](remote_bedrock.md)
- [remote::cerebras](remote_cerebras.md) - [remote::cerebras](remote_cerebras.md)
View file
@ -1,29 +0,0 @@
# inline::vllm
## Description
vLLM inference provider for high-performance model serving with PagedAttention and continuous batching.
## Configuration
| Field | Type | Required | Default | Description |
|-------|------|----------|---------|-------------|
| `tensor_parallel_size` | `<class 'int'>` | No | 1 | Number of tensor parallel replicas (number of GPUs to use). |
| `max_tokens` | `<class 'int'>` | No | 4096 | Maximum number of tokens to generate. |
| `max_model_len` | `<class 'int'>` | No | 4096 | Maximum context length to use during serving. |
| `max_num_seqs` | `<class 'int'>` | No | 4 | Maximum parallel batch size for generation. |
| `enforce_eager` | `<class 'bool'>` | No | False | Whether to use eager mode for inference (otherwise cuda graphs are used). |
| `gpu_memory_utilization` | `<class 'float'>` | No | 0.3 | How much GPU memory will be allocated when this provider has finished loading, including memory that was already allocated before loading. |
## Sample Configuration
```yaml
tensor_parallel_size: ${env.TENSOR_PARALLEL_SIZE:=1}
max_tokens: ${env.MAX_TOKENS:=4096}
max_model_len: ${env.MAX_MODEL_LEN:=4096}
max_num_seqs: ${env.MAX_NUM_SEQS:=4}
enforce_eager: ${env.ENFORCE_EAGER:=False}
gpu_memory_utilization: ${env.GPU_MEMORY_UTILIZATION:=0.3}
```
View file
@ -9,6 +9,8 @@ Ollama inference provider for running local models through the Ollama runtime.
| Field | Type | Required | Default | Description | | Field | Type | Required | Default | Description |
|-------|------|----------|---------|-------------| |-------|------|----------|---------|-------------|
| `url` | `<class 'str'>` | No | http://localhost:11434 | | | `url` | `<class 'str'>` | No | http://localhost:11434 | |
| `refresh_models` | `<class 'bool'>` | No | False | refresh and re-register models periodically |
| `refresh_models_interval` | `<class 'int'>` | No | 300 | interval in seconds to refresh models |
## Sample Configuration ## Sample Configuration
View file
@ -12,11 +12,13 @@ Remote vLLM inference provider for connecting to vLLM servers.
| `max_tokens` | `<class 'int'>` | No | 4096 | Maximum number of tokens to generate. | | `max_tokens` | `<class 'int'>` | No | 4096 | Maximum number of tokens to generate. |
| `api_token` | `str \| None` | No | fake | The API token | | `api_token` | `str \| None` | No | fake | The API token |
| `tls_verify` | `bool \| str` | No | True | Whether to verify TLS certificates. Can be a boolean or a path to a CA certificate file. | | `tls_verify` | `bool \| str` | No | True | Whether to verify TLS certificates. Can be a boolean or a path to a CA certificate file. |
| `refresh_models` | `<class 'bool'>` | No | False | Whether to refresh models periodically |
| `refresh_models_interval` | `<class 'int'>` | No | 300 | Interval in seconds to refresh models |
## Sample Configuration ## Sample Configuration
```yaml ```yaml
url: ${env.VLLM_URL} url: ${env.VLLM_URL:=}
max_tokens: ${env.VLLM_MAX_TOKENS:=4096} max_tokens: ${env.VLLM_MAX_TOKENS:=4096}
api_token: ${env.VLLM_API_TOKEN:=fake} api_token: ${env.VLLM_API_TOKEN:=fake}
tls_verify: ${env.VLLM_TLS_VERIFY:=true} tls_verify: ${env.VLLM_TLS_VERIFY:=true}
View file
@ -1,14 +1,14 @@
# OpenAI API Compatibility ## OpenAI API Compatibility
## Server path ### Server path
Llama Stack exposes an OpenAI-compatible API endpoint at `/v1/openai/v1`. So, for a Llama Stack server running locally on port `8321`, the full url to the OpenAI-compatible API endpoint is `http://localhost:8321/v1/openai/v1`. Llama Stack exposes an OpenAI-compatible API endpoint at `/v1/openai/v1`. So, for a Llama Stack server running locally on port `8321`, the full url to the OpenAI-compatible API endpoint is `http://localhost:8321/v1/openai/v1`.
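As a quick sanity check, you can hit the standard OpenAI models route on that path with `curl` (a minimal sketch, assuming the server is running locally on the default port):
```bash
# Lists the models exposed through the OpenAI-compatible endpoint
curl http://localhost:8321/v1/openai/v1/models
```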
## Clients ### Clients
You should be able to use any client that speaks OpenAI APIs with Llama Stack. We regularly test with the official Llama Stack clients as well as OpenAI's official Python client. You should be able to use any client that speaks OpenAI APIs with Llama Stack. We regularly test with the official Llama Stack clients as well as OpenAI's official Python client.
### Llama Stack Client #### Llama Stack Client
When using the Llama Stack client, set the `base_url` to the root of your Llama Stack server. It will automatically route OpenAI-compatible requests to the right server endpoint for you. When using the Llama Stack client, set the `base_url` to the root of your Llama Stack server. It will automatically route OpenAI-compatible requests to the right server endpoint for you.
@ -18,7 +18,7 @@ from llama_stack_client import LlamaStackClient
client = LlamaStackClient(base_url="http://localhost:8321") client = LlamaStackClient(base_url="http://localhost:8321")
``` ```
### OpenAI Client #### OpenAI Client
When using an OpenAI client, set the `base_url` to the `/v1/openai/v1` path on your Llama Stack server. When using an OpenAI client, set the `base_url` to the `/v1/openai/v1` path on your Llama Stack server.
@ -30,9 +30,9 @@ client = OpenAI(base_url="http://localhost:8321/v1/openai/v1", api_key="none")
Regardless of the client you choose, the following code examples should all work the same. Regardless of the client you choose, the following code examples should all work the same.
## APIs implemented ### APIs implemented
### Models #### Models
Many of the APIs require you to pass in a model parameter. To see the list of models available in your Llama Stack server: Many of the APIs require you to pass in a model parameter. To see the list of models available in your Llama Stack server:
@ -40,13 +40,13 @@ Many of the APIs require you to pass in a model parameter. To see the list of mo
models = client.models.list() models = client.models.list()
``` ```
### Responses #### Responses
:::{note} :::{note}
The Responses API implementation is still in active development. While it is quite usable, there are still unimplemented parts of the API. We'd love feedback on any use-cases you try that do not work to help prioritize the pieces left to implement. Please open issues in the [meta-llama/llama-stack](https://github.com/meta-llama/llama-stack) GitHub repository with details of anything that does not work. The Responses API implementation is still in active development. While it is quite usable, there are still unimplemented parts of the API. We'd love feedback on any use-cases you try that do not work to help prioritize the pieces left to implement. Please open issues in the [meta-llama/llama-stack](https://github.com/meta-llama/llama-stack) GitHub repository with details of anything that does not work.
::: :::
#### Simple inference ##### Simple inference
Request: Request:
@ -66,7 +66,7 @@ Syntax whispers secrets sweet
Code's gentle silence Code's gentle silence
``` ```
#### Structured Output ##### Structured Output
Request: Request:
@ -106,9 +106,9 @@ Example output:
{ "participants": ["Alice", "Bob"] } { "participants": ["Alice", "Bob"] }
``` ```
### Chat Completions #### Chat Completions
#### Simple inference ##### Simple inference
Request: Request:
@ -129,7 +129,7 @@ Logic flows like a river
Code's gentle beauty Code's gentle beauty
``` ```
#### Structured Output ##### Structured Output
Request: Request:
@ -170,9 +170,9 @@ Example output:
{ "participants": ["Alice", "Bob"] } { "participants": ["Alice", "Bob"] }
``` ```
### Completions #### Completions
#### Simple inference ##### Simple inference
Request: Request:
View file
@ -114,7 +114,7 @@ For more details on TLS configuration, refer to the [TLS setup guide](https://mi
| `uri` | `<class 'str'>` | No | PydanticUndefined | The URI of the Milvus server | | `uri` | `<class 'str'>` | No | PydanticUndefined | The URI of the Milvus server |
| `token` | `str \| None` | No | PydanticUndefined | The token of the Milvus server | | `token` | `str \| None` | No | PydanticUndefined | The token of the Milvus server |
| `consistency_level` | `<class 'str'>` | No | Strong | The consistency level of the Milvus server | | `consistency_level` | `<class 'str'>` | No | Strong | The consistency level of the Milvus server |
| `kvstore` | `utils.kvstore.config.RedisKVStoreConfig \| utils.kvstore.config.SqliteKVStoreConfig \| utils.kvstore.config.PostgresKVStoreConfig \| utils.kvstore.config.MongoDBKVStoreConfig, annotation=NoneType, required=False, default='sqlite', discriminator='type'` | No | | Config for KV store backend (SQLite only for now) | | `kvstore` | `utils.kvstore.config.RedisKVStoreConfig \| utils.kvstore.config.SqliteKVStoreConfig \| utils.kvstore.config.PostgresKVStoreConfig \| utils.kvstore.config.MongoDBKVStoreConfig` | No | sqlite | Config for KV store backend |
| `config` | `dict` | No | {} | This configuration allows additional fields to be passed through to the underlying Milvus client. See the [Milvus](https://milvus.io/docs/install-overview.md) documentation for more details about Milvus in general. | | `config` | `dict` | No | {} | This configuration allows additional fields to be passed through to the underlying Milvus client. See the [Milvus](https://milvus.io/docs/install-overview.md) documentation for more details about Milvus in general. |
> **Note**: This configuration class accepts additional fields beyond those listed above. You can pass any additional configuration options that will be forwarded to the underlying provider. > **Note**: This configuration class accepts additional fields beyond those listed above. You can pass any additional configuration options that will be forwarded to the underlying provider.
@ -124,6 +124,9 @@ For more details on TLS configuration, refer to the [TLS setup guide](https://mi
```yaml ```yaml
uri: ${env.MILVUS_ENDPOINT} uri: ${env.MILVUS_ENDPOINT}
token: ${env.MILVUS_TOKEN} token: ${env.MILVUS_TOKEN}
kvstore:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/dummy}/milvus_remote_registry.db
``` ```
View file
@ -40,6 +40,7 @@ See [PGVector's documentation](https://github.com/pgvector/pgvector) for more de
| `db` | `str \| None` | No | postgres | | | `db` | `str \| None` | No | postgres | |
| `user` | `str \| None` | No | postgres | | | `user` | `str \| None` | No | postgres | |
| `password` | `str \| None` | No | mysecretpassword | | | `password` | `str \| None` | No | mysecretpassword | |
| `kvstore` | `utils.kvstore.config.RedisKVStoreConfig \| utils.kvstore.config.SqliteKVStoreConfig \| utils.kvstore.config.PostgresKVStoreConfig \| utils.kvstore.config.MongoDBKVStoreConfig` | No | sqlite | Config for KV store backend (SQLite only for now) |
## Sample Configuration ## Sample Configuration
@ -49,6 +50,9 @@ port: ${env.PGVECTOR_PORT:=5432}
db: ${env.PGVECTOR_DB} db: ${env.PGVECTOR_DB}
user: ${env.PGVECTOR_USER} user: ${env.PGVECTOR_USER}
password: ${env.PGVECTOR_PASSWORD} password: ${env.PGVECTOR_PASSWORD}
kvstore:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/dummy}/pgvector_registry.db
``` ```
View file
@ -36,7 +36,9 @@ See [Weaviate's documentation](https://weaviate.io/developers/weaviate) for more
## Sample Configuration ## Sample Configuration
```yaml ```yaml
{} kvstore:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/dummy}/weaviate_registry.db
``` ```
View file
@ -9,7 +9,8 @@ The `llama-stack-client` CLI allows you to query information about the distribut
llama-stack-client llama-stack-client
Usage: llama-stack-client [OPTIONS] COMMAND [ARGS]... Usage: llama-stack-client [OPTIONS] COMMAND [ARGS]...
Welcome to the LlamaStackClient CLI Welcome to the llama-stack-client CLI - a command-line interface for
interacting with Llama Stack
Options: Options:
--version Show the version and exit. --version Show the version and exit.
@ -35,6 +36,7 @@ Commands:
``` ```
### `llama-stack-client configure` ### `llama-stack-client configure`
Configure Llama Stack Client CLI.
```bash ```bash
llama-stack-client configure llama-stack-client configure
> Enter the host name of the Llama Stack distribution server: localhost > Enter the host name of the Llama Stack distribution server: localhost
@ -42,7 +44,24 @@ llama-stack-client configure
Done! You can now use the Llama Stack Client CLI with endpoint http://localhost:8321 Done! You can now use the Llama Stack Client CLI with endpoint http://localhost:8321
``` ```
Optional arguments:
- `--endpoint`: Llama Stack distribution endpoint
- `--api-key`: Llama Stack distribution API key
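For example, to point the CLI at a locally running server without the interactive prompts (endpoint value shown is just the local default):
```bash
llama-stack-client configure --endpoint http://localhost:8321
```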
### `llama-stack-client inspect version`
Inspect server configuration.
```bash
llama-stack-client inspect version
```
```bash
VersionInfo(version='0.2.14')
```
### `llama-stack-client providers list` ### `llama-stack-client providers list`
Show available providers on distribution endpoint
```bash ```bash
llama-stack-client providers list llama-stack-client providers list
``` ```
@ -66,9 +85,74 @@ llama-stack-client providers list
+-----------+----------------+-----------------+ +-----------+----------------+-----------------+
``` ```
### `llama-stack-client providers inspect`
Show specific provider configuration on distribution endpoint
```bash
llama-stack-client providers inspect <provider_id>
```
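For example, assuming a provider with ID `ollama` is configured on the server:
```bash
llama-stack-client providers inspect ollama
```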
## Inference
Inference (chat).
### `llama-stack-client inference chat-completion`
Run an inference chat completion against the distribution endpoint
```bash
llama-stack-client inference chat-completion --message <message> [--stream] [--session] [--model-id]
```
```bash
OpenAIChatCompletion(
id='chatcmpl-aacd11f3-8899-4ec5-ac5b-e655132f6891',
choices=[
OpenAIChatCompletionChoice(
finish_reason='stop',
index=0,
message=OpenAIChatCompletionChoiceMessageOpenAIAssistantMessageParam(
role='assistant',
content='The captain of the whaleship Pequod in Nathaniel Hawthorne\'s novel "Moby-Dick" is Captain
Ahab. He\'s a vengeful and obsessive old sailor who\'s determined to hunt down and kill the white sperm whale
Moby-Dick, whom he\'s lost his leg to in a previous encounter.',
name=None,
tool_calls=None,
refusal=None,
annotations=None,
audio=None,
function_call=None
),
logprobs=None
)
],
created=1752578797,
model='llama3.2:3b-instruct-fp16',
object='chat.completion',
service_tier=None,
system_fingerprint='fp_ollama',
usage={
'completion_tokens': 67,
'prompt_tokens': 33,
'total_tokens': 100,
'completion_tokens_details': None,
'prompt_tokens_details': None
}
)
```
Required arguments:
**Note:** At least one of these parameters is required for chat completion
- `--message`: Message
- `--session`: Start a Chat Session
Optional arguments:
- `--stream`: Stream
- `--model-id`: Model ID
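For illustration, the sample response above could have come from a single non-streaming call with a message along these lines:
```bash
llama-stack-client inference chat-completion --message "Who was the captain of the Pequod?"
```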
## Model Management ## Model Management
Manage GenAI models.
### `llama-stack-client models list` ### `llama-stack-client models list`
Show available llama models at distribution endpoint
```bash ```bash
llama-stack-client models list llama-stack-client models list
``` ```
@ -85,6 +169,7 @@ Total models: 1
``` ```
### `llama-stack-client models get` ### `llama-stack-client models get`
Show details of a specific model at the distribution endpoint
```bash ```bash
llama-stack-client models get Llama3.1-8B-Instruct llama-stack-client models get Llama3.1-8B-Instruct
``` ```
@ -105,69 +190,92 @@ Model RandomModel is not found at distribution endpoint host:port. Please ensure
``` ```
### `llama-stack-client models register` ### `llama-stack-client models register`
Register a new model at distribution endpoint
```bash ```bash
llama-stack-client models register <model_id> [--provider-id <provider_id>] [--provider-model-id <provider_model_id>] [--metadata <metadata>] llama-stack-client models register <model_id> [--provider-id <provider_id>] [--provider-model-id <provider_model_id>] [--metadata <metadata>] [--model-type <model_type>]
``` ```
### `llama-stack-client models update` Required arguments:
- `MODEL_ID`: Model ID
- `--provider-id`: Provider ID for the model
Optional arguments:
- `--provider-model-id`: Provider's model ID
- `--metadata`: JSON metadata for the model
- `--model-type`: Model type: `llm`, `embedding`
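As an illustration (the IDs below are hypothetical), registering an Ollama-served LLM might look like:
```bash
llama-stack-client models register my-model --provider-id ollama --provider-model-id llama3.2:3b --model-type llm
```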
### `llama-stack-client models unregister`
Unregister a model from distribution endpoint
```bash ```bash
llama-stack-client models update <model_id> [--provider-id <provider_id>] [--provider-model-id <provider_model_id>] [--metadata <metadata>] llama-stack-client models unregister <model_id>
```
### `llama-stack-client models delete`
```bash
llama-stack-client models delete <model_id>
``` ```
## Vector DB Management ## Vector DB Management
Manage vector databases.
### `llama-stack-client vector_dbs list` ### `llama-stack-client vector_dbs list`
Show available vector dbs on distribution endpoint
```bash ```bash
llama-stack-client vector_dbs list llama-stack-client vector_dbs list
``` ```
``` ```
+--------------+----------------+---------------------+---------------+------------------------+ ┏━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
| identifier | provider_id | provider_resource_id| vector_db_type| params | ┃ identifier ┃ provider_id ┃ provider_resource_id ┃ vector_db_type ┃ params ┃
+==============+================+=====================+===============+========================+ ┡━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
| test_bank | meta-reference | test_bank | vector | embedding_model: all-MiniLM-L6-v2 │ my_demo_vector_db │ faiss │ my_demo_vector_db │ │ embedding_dimension: 384 │
embedding_dimension: 384| │ │ │ │ │ embedding_model: all-MiniLM-L6-v2 │
+--------------+----------------+---------------------+---------------+------------------------+ │ │ │ │ │ type: vector_db │
│ │ │ │ │ │
└──────────────────────────┴─────────────┴──────────────────────────┴────────────────┴───────────────────────────────────┘
``` ```
### `llama-stack-client vector_dbs register` ### `llama-stack-client vector_dbs register`
Create a new vector db
```bash ```bash
llama-stack-client vector_dbs register <vector-db-id> [--provider-id <provider-id>] [--provider-vector-db-id <provider-vector-db-id>] [--embedding-model <embedding-model>] [--embedding-dimension <embedding-dimension>] llama-stack-client vector_dbs register <vector-db-id> [--provider-id <provider-id>] [--provider-vector-db-id <provider-vector-db-id>] [--embedding-model <embedding-model>] [--embedding-dimension <embedding-dimension>]
``` ```
Required arguments:
- `VECTOR_DB_ID`: Vector DB ID
Optional arguments: Optional arguments:
- `--provider-id`: Provider ID for the vector db - `--provider-id`: Provider ID for the vector db
- `--provider-vector-db-id`: Provider's vector db ID - `--provider-vector-db-id`: Provider's vector db ID
- `--embedding-model`: Embedding model to use. Default: "all-MiniLM-L6-v2" - `--embedding-model`: Embedding model to use. Default: `all-MiniLM-L6-v2`
- `--embedding-dimension`: Dimension of embeddings. Default: 384 - `--embedding-dimension`: Dimension of embeddings. Default: 384
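For example, to create a vector db like the one used in the quickstart (the IDs are illustrative):
```bash
llama-stack-client vector_dbs register my_demo_vector_db --provider-id faiss
```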
### `llama-stack-client vector_dbs unregister` ### `llama-stack-client vector_dbs unregister`
Delete a vector db
```bash ```bash
llama-stack-client vector_dbs unregister <vector-db-id> llama-stack-client vector_dbs unregister <vector-db-id>
``` ```
Required arguments:
- `VECTOR_DB_ID`: Vector DB ID
## Shield Management ## Shield Management
Manage safety shield services.
### `llama-stack-client shields list` ### `llama-stack-client shields list`
Show available safety shields on distribution endpoint
```bash ```bash
llama-stack-client shields list llama-stack-client shields list
``` ```
``` ```
+--------------+----------+----------------+-------------+ ┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
| identifier | params | provider_id | type | ┃ identifier ┃ provider_alias ┃ params ┃ provider_id ┃
+==============+==========+================+=============+ ┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
| llama_guard | {} | meta-reference | llama_guard | │ ollama │ ollama/llama-guard3:1b │ │ llama-guard │
+--------------+----------+----------------+-------------+ └──────────────────────────────────┴───────────────────────────────────────────────────────────────────────┴───────────────────────┴────────────────────────────────────┘
``` ```
### `llama-stack-client shields register` ### `llama-stack-client shields register`
Register a new safety shield
```bash ```bash
llama-stack-client shields register --shield-id <shield-id> [--provider-id <provider-id>] [--provider-shield-id <provider-shield-id>] [--params <params>] llama-stack-client shields register --shield-id <shield-id> [--provider-id <provider-id>] [--provider-shield-id <provider-shield-id>] [--params <params>]
``` ```
@ -180,41 +288,29 @@ Optional arguments:
- `--provider-shield-id`: Provider's shield ID - `--provider-shield-id`: Provider's shield ID
- `--params`: JSON configuration parameters for the shield - `--params`: JSON configuration parameters for the shield
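For example (the shield and provider-shield IDs below are placeholders):
```bash
llama-stack-client shields register --shield-id my-shield --provider-shield-id llama-guard3:1b
```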
## Eval Task Management
### `llama-stack-client benchmarks list`
```bash
llama-stack-client benchmarks list
```
### `llama-stack-client benchmarks register`
```bash
llama-stack-client benchmarks register --eval-task-id <eval-task-id> --dataset-id <dataset-id> --scoring-functions <function1> [<function2> ...] [--provider-id <provider-id>] [--provider-eval-task-id <provider-eval-task-id>] [--metadata <metadata>]
```
Required arguments:
- `--eval-task-id`: ID of the eval task
- `--dataset-id`: ID of the dataset to evaluate
- `--scoring-functions`: One or more scoring functions to use for evaluation
Optional arguments:
- `--provider-id`: Provider ID for the eval task
- `--provider-eval-task-id`: Provider's eval task ID
- `--metadata`: Metadata for the eval task in JSON format
## Eval execution ## Eval execution
Run evaluation tasks.
### `llama-stack-client eval run-benchmark` ### `llama-stack-client eval run-benchmark`
Run an evaluation benchmark task
```bash ```bash
llama-stack-client eval run-benchmark <eval-task-id1> [<eval-task-id2> ...] --eval-task-config <config-file> --output-dir <output-dir> [--num-examples <num>] [--visualize] llama-stack-client eval run-benchmark <eval-task-id1> [<eval-task-id2> ...] --eval-task-config <config-file> --output-dir <output-dir> --model-id <model-id> [--num-examples <num>] [--visualize] [--repeat-penalty <repeat-penalty>] [--top-p <top-p>] [--max-tokens <max-tokens>]
``` ```
Required arguments: Required arguments:
- `--eval-task-config`: Path to the eval task config file in JSON format - `--eval-task-config`: Path to the eval task config file in JSON format
- `--output-dir`: Path to the directory where evaluation results will be saved - `--output-dir`: Path to the directory where evaluation results will be saved
- `--model-id`: Model ID to run the benchmark eval on
Optional arguments: Optional arguments:
- `--num-examples`: Number of examples to evaluate (useful for debugging) - `--num-examples`: Number of examples to evaluate (useful for debugging)
- `--visualize`: If set, visualizes evaluation results after completion - `--visualize`: If set, visualizes evaluation results after completion
- `--repeat-penalty`: Repeat penalty to use in the sampling params for generation
- `--top-p`: Top-p value to use in the sampling params for generation
- `--max-tokens`: Maximum number of tokens to generate
- `--temperature`: Temperature to use in the sampling params for generation
Example benchmark_config.json: Example benchmark_config.json:
```json ```json
@ -231,21 +327,55 @@ Example benchmark_config.json:
``` ```
### `llama-stack-client eval run-scoring` ### `llama-stack-client eval run-scoring`
Run scoring from application datasets
```bash ```bash
llama-stack-client eval run-scoring <eval-task-id> --eval-task-config <config-file> --output-dir <output-dir> [--num-examples <num>] [--visualize] llama-stack-client eval run-scoring <eval-task-id> --output-dir <output-dir> [--num-examples <num>] [--visualize]
``` ```
Required arguments: Required arguments:
- `--eval-task-config`: Path to the eval task config file in JSON format
- `--output-dir`: Path to the directory where scoring results will be saved - `--output-dir`: Path to the directory where scoring results will be saved
Optional arguments: Optional arguments:
- `--num-examples`: Number of examples to evaluate (useful for debugging) - `--num-examples`: Number of examples to evaluate (useful for debugging)
- `--visualize`: If set, visualizes scoring results after completion - `--visualize`: If set, visualizes scoring results after completion
- `--scoring-params-config`: Path to the scoring params config file in JSON format
- `--dataset-id`: Pre-registered dataset_id to score (from llama-stack-client datasets list)
- `--dataset-path`: Path to the dataset file to score
## Eval Tasks
Manage evaluation tasks.
### `llama-stack-client eval_tasks list`
Show available eval tasks on distribution endpoint
```bash
llama-stack-client eval_tasks list
```
### `llama-stack-client eval_tasks register`
Register a new eval task
```bash
llama-stack-client eval_tasks register --eval-task-id <eval-task-id> --dataset-id <dataset-id> --scoring-functions <scoring-functions> [--provider-id <provider-id>] [--provider-eval-task-id <provider-eval-task-id>] [--metadata <metadata>]
```
Required arguments:
- `--eval-task-id`: ID of the eval task
- `--dataset-id`: ID of the dataset to evaluate
- `--scoring-functions`: Scoring functions to use for evaluation
Optional arguments:
- `--provider-id`: Provider ID for the eval task
- `--provider-eval-task-id`: Provider's eval task ID
## Tool Group Management ## Tool Group Management
Manage available tool groups.
### `llama-stack-client toolgroups list` ### `llama-stack-client toolgroups list`
Show available llama toolgroups at distribution endpoint
```bash ```bash
llama-stack-client toolgroups list llama-stack-client toolgroups list
``` ```
@ -260,17 +390,28 @@ llama-stack-client toolgroups list
``` ```
### `llama-stack-client toolgroups get` ### `llama-stack-client toolgroups get`
Get available llama toolgroups by id
```bash ```bash
llama-stack-client toolgroups get <toolgroup_id> llama-stack-client toolgroups get <toolgroup_id>
``` ```
Shows detailed information about a specific toolgroup. If the toolgroup is not found, displays an error message. Shows detailed information about a specific toolgroup. If the toolgroup is not found, displays an error message.
Required arguments:
- `TOOLGROUP_ID`: ID of the tool group
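For example, assuming a toolgroup with ID `builtin::rag` is available (the ID here is an assumption based on the `builtin::rag/knowledge_search` tool used in the quickstart):
```bash
llama-stack-client toolgroups get builtin::rag
```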
### `llama-stack-client toolgroups register` ### `llama-stack-client toolgroups register`
Register a new toolgroup at distribution endpoint
```bash ```bash
llama-stack-client toolgroups register <toolgroup_id> [--provider-id <provider-id>] [--provider-toolgroup-id <provider-toolgroup-id>] [--mcp-config <mcp-config>] [--args <args>] llama-stack-client toolgroups register <toolgroup_id> [--provider-id <provider-id>] [--provider-toolgroup-id <provider-toolgroup-id>] [--mcp-config <mcp-config>] [--args <args>]
``` ```
Required arguments:
- `TOOLGROUP_ID`: ID of the tool group
Optional arguments: Optional arguments:
- `--provider-id`: Provider ID for the toolgroup - `--provider-id`: Provider ID for the toolgroup
- `--provider-toolgroup-id`: Provider's toolgroup ID - `--provider-toolgroup-id`: Provider's toolgroup ID
@ -278,6 +419,172 @@ Optional arguments:
- `--args`: JSON arguments for the toolgroup - `--args`: JSON arguments for the toolgroup
### `llama-stack-client toolgroups unregister` ### `llama-stack-client toolgroups unregister`
Unregister a toolgroup from distribution endpoint
```bash ```bash
llama-stack-client toolgroups unregister <toolgroup_id> llama-stack-client toolgroups unregister <toolgroup_id>
``` ```
Required arguments:
- `TOOLGROUP_ID`: ID of the tool group
## Datasets Management
Manage datasets.
### `llama-stack-client datasets list`
Show available datasets on distribution endpoint
```bash
llama-stack-client datasets list
```
### `llama-stack-client datasets register`
```bash
llama-stack-client datasets register --dataset_id <dataset_id> --purpose <purpose> [--url <url>] [--dataset-path <dataset-path>] [--dataset-id <dataset-id>] [--metadata <metadata>]
```
Required arguments:
- `--dataset_id`: Id of the dataset
- `--purpose`: Purpose of the dataset
Optional arguments:
- `--metadata`: Metadata of the dataset
- `--url`: URL of the dataset
- `--dataset-path`: Local file path to the dataset. If specified, the dataset file is uploaded instead of being referenced by `--url`
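A hypothetical registration from a hosted file might look like the following; the purpose string and URL are placeholders, so use values supported by your distribution:
```bash
# dataset_id, purpose, and url values are illustrative placeholders
llama-stack-client datasets register \
  --dataset_id my_dataset \
  --purpose "eval/question-answer" \
  --url "https://example.com/my_dataset.jsonl"
```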
### `llama-stack-client datasets unregister`
Remove a dataset
```bash
llama-stack-client datasets unregister <dataset-id>
```
Required arguments:
- `DATASET_ID`: Id of the dataset
## Scoring Functions Management
Manage scoring functions.
### `llama-stack-client scoring_functions list`
Show available scoring functions on distribution endpoint
```bash
llama-stack-client scoring_functions list
```
```
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━┓
┃ identifier ┃ provider_id ┃ description ┃ type ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━┩
│ basic::bfcl │ basic │ BFCL complex scoring │ scoring_function │
│ basic::docvqa │ basic │ DocVQA Visual Question & Answer scoring function │ scoring_function │
│ basic::equality │ basic │ Returns 1.0 if the input is equal to the target, 0.0 │ scoring_function │
│ │ │ otherwise. │ │
└────────────────────────────────────────────┴──────────────┴───────────────────────────────────────────────────────────────┴──────────────────┘
```
### `llama-stack-client scoring_functions register`
Register a new scoring function
```bash
llama-stack-client scoring_functions register --scoring-fn-id <scoring-fn-id> --description <description> --return-type <return-type> [--provider-id <provider-id>] [--provider-scoring-fn-id <provider-scoring-fn-id>] [--params <params>]
```
Required arguments:
- `--scoring-fn-id`: Id of the scoring function
- `--description`: Description of the scoring function
- `--return-type`: Return type of the scoring function
Optional arguments:
- `--provider-id`: Provider ID for the scoring function
- `--provider-scoring-fn-id`: Provider's scoring function ID
- `--params`: Parameters for the scoring function in JSON format
## Post Training Management
Post-training.
### `llama-stack-client post_training list`
Show the list of available post training jobs
```bash
llama-stack-client post_training list
```
```bash
["job-1", "job-2", "job-3"]
```
### `llama-stack-client post_training artifacts`
Get the training artifacts of a specific post training job
```bash
llama-stack-client post_training artifacts --job-uuid <job-uuid>
```
```bash
JobArtifactsResponse(checkpoints=[], job_uuid='job-1')
```
Required arguments:
- `--job-uuid`: Job UUID
### `llama-stack-client post_training supervised_fine_tune`
Kick off a supervised fine tune job
```bash
llama-stack-client post_training supervised_fine_tune --job-uuid <job-uuid> --model <model> --algorithm-config <algorithm-config> --training-config <training-config> [--checkpoint-dir <checkpoint-dir>]
```
Required arguments:
- `--job-uuid`: Job UUID
- `--model`: Model ID
- `--algorithm-config`: Algorithm Config
- `--training-config`: Training Config
Optional arguments:
- `--checkpoint-dir`: Checkpoint Config
### `llama-stack-client post_training status`
Show the status of a specific post training job
```bash
llama-stack-client post_training status --job-uuid <job-uuid>
```
```bash
JobStatusResponse(
checkpoints=[],
job_uuid='job-1',
status='completed',
completed_at="",
resources_allocated="",
scheduled_at="",
started_at=""
)
```
Required arguments:
- `--job-uuid`: Job UUID
### `llama-stack-client post_training cancel`
Cancel the training job
```bash
llama-stack-client post_training cancel --job-uuid <job-uuid>
```
```bash
# This functionality is not yet implemented for llama-stack-client
╭────────────────────────────────────────────────────────────╮
│ Failed to post_training cancel_training_job │
│ │
│ Error Type: InternalServerError │
│ Details: Error code: 501 - {'detail': 'Not implemented: '} │
╰────────────────────────────────────────────────────────────╯
```
Required arguments:
- `--job-uuid`: Job UUID
View file
@ -819,6 +819,12 @@ class OpenAIEmbeddingsResponse(BaseModel):
class ModelStore(Protocol): class ModelStore(Protocol):
async def get_model(self, identifier: str) -> Model: ... async def get_model(self, identifier: str) -> Model: ...
async def update_registered_llm_models(
self,
provider_id: str,
models: list[Model],
) -> None: ...
class TextTruncation(Enum): class TextTruncation(Enum):
"""Config for how to truncate text for embedding when text is longer than the model's max sequence length. Start and End semantics depend on whether the language is left-to-right or right-to-left. """Config for how to truncate text for embedding when text is longer than the model's max sequence length. Start and End semantics depend on whether the language is left-to-right or right-to-left.
View file
@ -7,7 +7,7 @@
from enum import StrEnum from enum import StrEnum
from typing import Any, Literal, Protocol, runtime_checkable from typing import Any, Literal, Protocol, runtime_checkable
from pydantic import BaseModel, ConfigDict, Field from pydantic import BaseModel, ConfigDict, Field, field_validator
from llama_stack.apis.resource import Resource, ResourceType from llama_stack.apis.resource import Resource, ResourceType
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
@ -36,13 +36,21 @@ class Model(CommonModelFields, Resource):
return self.identifier return self.identifier
@property @property
def provider_model_id(self) -> str | None: def provider_model_id(self) -> str:
assert self.provider_resource_id is not None, "Provider resource ID must be set"
return self.provider_resource_id return self.provider_resource_id
model_config = ConfigDict(protected_namespaces=()) model_config = ConfigDict(protected_namespaces=())
model_type: ModelType = Field(default=ModelType.llm) model_type: ModelType = Field(default=ModelType.llm)
@field_validator("provider_resource_id")
@classmethod
def validate_provider_resource_id(cls, v):
if v is None:
raise ValueError("provider_resource_id cannot be None")
return v
class ModelInput(CommonModelFields): class ModelInput(CommonModelFields):
model_id: str model_id: str
View file
@ -104,12 +104,18 @@ class RLHFAlgorithm(Enum):
dpo = "dpo" dpo = "dpo"
@json_schema_type
class DPOLossType(Enum):
sigmoid = "sigmoid"
hinge = "hinge"
ipo = "ipo"
kto_pair = "kto_pair"
@json_schema_type @json_schema_type
class DPOAlignmentConfig(BaseModel): class DPOAlignmentConfig(BaseModel):
reward_scale: float beta: float
reward_clip: float loss_type: DPOLossType = DPOLossType.sigmoid
epsilon: float
gamma: float
@json_schema_type @json_schema_type
View file
@ -19,6 +19,7 @@ class VectorDB(Resource):
embedding_model: str embedding_model: str
embedding_dimension: int embedding_dimension: int
vector_db_name: str | None = None
@property @property
def vector_db_id(self) -> str: def vector_db_id(self) -> str:
@ -70,6 +71,7 @@ class VectorDBs(Protocol):
embedding_model: str, embedding_model: str,
embedding_dimension: int | None = 384, embedding_dimension: int | None = 384,
provider_id: str | None = None, provider_id: str | None = None,
vector_db_name: str | None = None,
provider_vector_db_id: str | None = None, provider_vector_db_id: str | None = None,
) -> VectorDB: ) -> VectorDB:
"""Register a vector database. """Register a vector database.
@ -78,6 +80,7 @@ class VectorDBs(Protocol):
:param embedding_model: The embedding model to use. :param embedding_model: The embedding model to use.
:param embedding_dimension: The dimension of the embedding model. :param embedding_dimension: The dimension of the embedding model.
:param provider_id: The identifier of the provider. :param provider_id: The identifier of the provider.
:param vector_db_name: The name of the vector database.
:param provider_vector_db_id: The identifier of the vector database in the provider. :param provider_vector_db_id: The identifier of the vector database in the provider.
:returns: A VectorDB. :returns: A VectorDB.
""" """
View file
@ -346,7 +346,6 @@ class VectorIO(Protocol):
embedding_model: str | None = None, embedding_model: str | None = None,
embedding_dimension: int | None = 384, embedding_dimension: int | None = 384,
provider_id: str | None = None, provider_id: str | None = None,
provider_vector_db_id: str | None = None,
) -> VectorStoreObject: ) -> VectorStoreObject:
"""Creates a vector store. """Creates a vector store.
@ -358,7 +357,6 @@ class VectorIO(Protocol):
:param embedding_model: The embedding model to use for this vector store. :param embedding_model: The embedding model to use for this vector store.
:param embedding_dimension: The dimension of the embedding vectors (default: 384). :param embedding_dimension: The dimension of the embedding vectors (default: 384).
:param provider_id: The ID of the provider to use for this vector store. :param provider_id: The ID of the provider to use for this vector store.
:param provider_vector_db_id: The provider-specific vector database ID.
:returns: A VectorStoreObject representing the created vector store. :returns: A VectorStoreObject representing the created vector store.
""" """
... ...

View file

@ -47,8 +47,7 @@ class StackRun(Subcommand):
self.parser.add_argument( self.parser.add_argument(
"--image-name", "--image-name",
type=str, type=str,
default=os.environ.get("CONDA_DEFAULT_ENV"), help="Name of the image to run.",
help="Name of the image to run. Defaults to the current environment",
) )
self.parser.add_argument( self.parser.add_argument(
"--env", "--env",

View file

@ -17,7 +17,7 @@ from llama_stack.distribution.distribution import (
builtin_automatically_routed_apis, builtin_automatically_routed_apis,
get_provider_registry, get_provider_registry,
) )
from llama_stack.distribution.stack import replace_env_vars from llama_stack.distribution.stack import cast_image_name_to_string, replace_env_vars
from llama_stack.distribution.utils.config_dirs import EXTERNAL_PROVIDERS_DIR from llama_stack.distribution.utils.config_dirs import EXTERNAL_PROVIDERS_DIR
from llama_stack.distribution.utils.dynamic import instantiate_class_type from llama_stack.distribution.utils.dynamic import instantiate_class_type
from llama_stack.distribution.utils.prompt_for_config import prompt_for_config from llama_stack.distribution.utils.prompt_for_config import prompt_for_config
@ -164,7 +164,8 @@ def upgrade_from_routing_table(
def parse_and_maybe_upgrade_config(config_dict: dict[str, Any]) -> StackRunConfig: def parse_and_maybe_upgrade_config(config_dict: dict[str, Any]) -> StackRunConfig:
version = config_dict.get("version", None) version = config_dict.get("version", None)
if version == LLAMA_STACK_RUN_CONFIG_VERSION: if version == LLAMA_STACK_RUN_CONFIG_VERSION:
return StackRunConfig(**replace_env_vars(config_dict)) processed_config_dict = replace_env_vars(config_dict)
return StackRunConfig(**cast_image_name_to_string(processed_config_dict))
if "routing_table" in config_dict: if "routing_table" in config_dict:
logger.info("Upgrading config...") logger.info("Upgrading config...")
@ -175,4 +176,5 @@ def parse_and_maybe_upgrade_config(config_dict: dict[str, Any]) -> StackRunConfi
if not config_dict.get("external_providers_dir", None): if not config_dict.get("external_providers_dir", None):
config_dict["external_providers_dir"] = EXTERNAL_PROVIDERS_DIR config_dict["external_providers_dir"] = EXTERNAL_PROVIDERS_DIR
return StackRunConfig(**replace_env_vars(config_dict)) processed_config_dict = replace_env_vars(config_dict)
return StackRunConfig(**cast_image_name_to_string(processed_config_dict))

View file

@ -12,11 +12,13 @@ import os
import sys import sys
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
from enum import Enum from enum import Enum
from io import BytesIO
from pathlib import Path from pathlib import Path
from typing import Any, TypeVar, Union, get_args, get_origin from typing import Any, TypeVar, Union, get_args, get_origin
import httpx import httpx
import yaml import yaml
from fastapi import Response as FastAPIResponse
from llama_stack_client import ( from llama_stack_client import (
NOT_GIVEN, NOT_GIVEN,
APIResponse, APIResponse,
@ -112,6 +114,27 @@ def convert_to_pydantic(annotation: Any, value: Any) -> Any:
raise ValueError(f"Failed to convert parameter {value} into {annotation}: {e}") from e raise ValueError(f"Failed to convert parameter {value} into {annotation}: {e}") from e
class LibraryClientUploadFile:
"""LibraryClient UploadFile object that mimics FastAPI's UploadFile interface."""
def __init__(self, filename: str, content: bytes):
self.filename = filename
self.content = content
self.content_type = "application/octet-stream"
async def read(self) -> bytes:
return self.content
class LibraryClientHttpxResponse:
"""LibraryClient httpx Response object for FastAPI Response conversion."""
def __init__(self, response):
self.content = response.body if isinstance(response.body, bytes) else response.body.encode()
self.status_code = response.status_code
self.headers = response.headers
class LlamaStackAsLibraryClient(LlamaStackClient): class LlamaStackAsLibraryClient(LlamaStackClient):
def __init__( def __init__(
self, self,
@ -128,6 +151,8 @@ class LlamaStackAsLibraryClient(LlamaStackClient):
self.skip_logger_removal = skip_logger_removal self.skip_logger_removal = skip_logger_removal
self.provider_data = provider_data self.provider_data = provider_data
self.loop = asyncio.new_event_loop()
def initialize(self): def initialize(self):
if in_notebook(): if in_notebook():
import nest_asyncio import nest_asyncio
@ -136,7 +161,7 @@ class LlamaStackAsLibraryClient(LlamaStackClient):
if not self.skip_logger_removal: if not self.skip_logger_removal:
self._remove_root_logger_handlers() self._remove_root_logger_handlers()
return asyncio.run(self.async_client.initialize()) return self.loop.run_until_complete(self.async_client.initialize())
def _remove_root_logger_handlers(self): def _remove_root_logger_handlers(self):
""" """
@ -149,10 +174,7 @@ class LlamaStackAsLibraryClient(LlamaStackClient):
logger.info(f"Removed handler {handler.__class__.__name__} from root logger") logger.info(f"Removed handler {handler.__class__.__name__} from root logger")
def request(self, *args, **kwargs): def request(self, *args, **kwargs):
# NOTE: We are using AsyncLlamaStackClient under the hood loop = self.loop
# A new event loop is needed to convert the AsyncStream
# from async client into SyncStream return type for streaming
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop) asyncio.set_event_loop(loop)
if kwargs.get("stream"): if kwargs.get("stream"):
@ -169,7 +191,6 @@ class LlamaStackAsLibraryClient(LlamaStackClient):
pending = asyncio.all_tasks(loop) pending = asyncio.all_tasks(loop)
if pending: if pending:
loop.run_until_complete(asyncio.gather(*pending, return_exceptions=True)) loop.run_until_complete(asyncio.gather(*pending, return_exceptions=True))
loop.close()
return sync_generator() return sync_generator()
else: else:
@ -179,7 +200,6 @@ class LlamaStackAsLibraryClient(LlamaStackClient):
pending = asyncio.all_tasks(loop) pending = asyncio.all_tasks(loop)
if pending: if pending:
loop.run_until_complete(asyncio.gather(*pending, return_exceptions=True)) loop.run_until_complete(asyncio.gather(*pending, return_exceptions=True))
loop.close()
return result return result
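
The change above reuses one event loop for the client's lifetime instead of creating and closing a loop per request. A minimal, Llama-Stack-free sketch of that pattern:

import asyncio

class SyncOverAsync:
    """Run async calls from synchronous code on a single persistent loop."""

    def __init__(self):
        # One loop per object: tasks scheduled by one call can survive into the next,
        # instead of dying when a throwaway loop is closed.
        self.loop = asyncio.new_event_loop()

    def request(self, coro):
        asyncio.set_event_loop(self.loop)
        result = self.loop.run_until_complete(coro)
        # Drain anything the call scheduled but did not await.
        pending = asyncio.all_tasks(self.loop)
        if pending:
            self.loop.run_until_complete(asyncio.gather(*pending, return_exceptions=True))
        return result

async def _demo():
    await asyncio.sleep(0)
    return "ok"

client = SyncOverAsync()
print(client.request(_demo()))  # "ok"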
@ -295,6 +315,31 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
) )
return response return response
def _handle_file_uploads(self, options: Any, body: dict) -> tuple[dict, list[str]]:
"""Handle file uploads from OpenAI client and add them to the request body."""
if not (hasattr(options, "files") and options.files):
return body, []
if not isinstance(options.files, list):
return body, []
field_names = []
for file_tuple in options.files:
if not (isinstance(file_tuple, tuple) and len(file_tuple) >= 2):
continue
field_name = file_tuple[0]
file_object = file_tuple[1]
if isinstance(file_object, BytesIO):
file_object.seek(0)
file_content = file_object.read()
filename = getattr(file_object, "name", "uploaded_file")
field_names.append(field_name)
body[field_name] = LibraryClientUploadFile(filename, file_content)
return body, field_names
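
A rough sketch of the shape _handle_file_uploads expects: an OpenAI-style files list of (field_name, file_object) tuples backed by BytesIO. Everything outside that layout (names, values) is illustrative:

from io import BytesIO

buf = BytesIO(b"hello world")
buf.name = "notes.txt"          # the wrapper falls back to "uploaded_file" if unset
files = [("file", buf)]         # what the OpenAI client puts in options.files

body = {"purpose": "assistants"}
field_names = []
for field_name, file_object in files:
    if isinstance(file_object, BytesIO):
        file_object.seek(0)
        content = file_object.read()
        filename = getattr(file_object, "name", "uploaded_file")
        field_names.append(field_name)
        # In the library client this becomes a LibraryClientUploadFile so the
        # route handler sees something shaped like FastAPI's UploadFile.
        body[field_name] = (filename, content)

print(field_names)   # ['file']
print(body["file"])  # ('notes.txt', b'hello world')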
async def _call_non_streaming( async def _call_non_streaming(
self, self,
*, *,
@ -310,15 +355,23 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
matched_func, path_params, route = find_matching_route(options.method, path, self.route_impls) matched_func, path_params, route = find_matching_route(options.method, path, self.route_impls)
body |= path_params body |= path_params
body = self._convert_body(path, options.method, body)
body, field_names = self._handle_file_uploads(options, body)
body = self._convert_body(path, options.method, body, exclude_params=set(field_names))
await start_trace(route, {"__location__": "library_client"}) await start_trace(route, {"__location__": "library_client"})
try: try:
result = await matched_func(**body) result = await matched_func(**body)
finally: finally:
await end_trace() await end_trace()
# Handle FastAPI Response objects (e.g., from file content retrieval)
if isinstance(result, FastAPIResponse):
return LibraryClientHttpxResponse(result)
json_content = json.dumps(convert_pydantic_to_json_value(result)) json_content = json.dumps(convert_pydantic_to_json_value(result))
filtered_body = {k: v for k, v in body.items() if not isinstance(v, LibraryClientUploadFile)}
mock_response = httpx.Response( mock_response = httpx.Response(
status_code=httpx.codes.OK, status_code=httpx.codes.OK,
content=json_content.encode("utf-8"), content=json_content.encode("utf-8"),
@ -330,7 +383,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
url=options.url, url=options.url,
params=options.params, params=options.params,
headers=options.headers or {}, headers=options.headers or {},
json=convert_pydantic_to_json_value(body), json=convert_pydantic_to_json_value(filtered_body),
), ),
) )
response = APIResponse( response = APIResponse(
@ -404,13 +457,17 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
) )
return await response.parse() return await response.parse()
def _convert_body(self, path: str, method: str, body: dict | None = None) -> dict: def _convert_body(
self, path: str, method: str, body: dict | None = None, exclude_params: set[str] | None = None
) -> dict:
if not body: if not body:
return {} return {}
if self.route_impls is None: if self.route_impls is None:
raise ValueError("Client not initialized") raise ValueError("Client not initialized")
exclude_params = exclude_params or set()
func, _, _ = find_matching_route(method, path, self.route_impls) func, _, _ = find_matching_route(method, path, self.route_impls)
sig = inspect.signature(func) sig = inspect.signature(func)
@ -422,6 +479,9 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
for param_name, param in sig.parameters.items(): for param_name, param in sig.parameters.items():
if param_name in body: if param_name in body:
value = body.get(param_name) value = body.get(param_name)
converted_body[param_name] = convert_to_pydantic(param.annotation, value) if param_name in exclude_params:
converted_body[param_name] = value
else:
converted_body[param_name] = convert_to_pydantic(param.annotation, value)
return converted_body return converted_body

View file

@ -200,7 +200,7 @@ def validate_and_prepare_providers(
specs = {} specs = {}
for provider in providers: for provider in providers:
if not provider.provider_id or provider.provider_id == "__disabled__": if not provider.provider_id or provider.provider_id == "__disabled__":
logger.warning(f"Provider `{provider.provider_type}` for API `{api}` is disabled") logger.debug(f"Provider `{provider.provider_type}` for API `{api}` is disabled")
continue continue
validate_provider(provider, api, provider_registry) validate_provider(provider, api, provider_registry)

View file

@ -5,6 +5,7 @@
# the root directory of this source tree. # the root directory of this source tree.
import asyncio import asyncio
import uuid
from typing import Any from typing import Any
from llama_stack.apis.common.content_types import ( from llama_stack.apis.common.content_types import (
@ -81,6 +82,7 @@ class VectorIORouter(VectorIO):
embedding_model: str, embedding_model: str,
embedding_dimension: int | None = 384, embedding_dimension: int | None = 384,
provider_id: str | None = None, provider_id: str | None = None,
vector_db_name: str | None = None,
provider_vector_db_id: str | None = None, provider_vector_db_id: str | None = None,
) -> None: ) -> None:
logger.debug(f"VectorIORouter.register_vector_db: {vector_db_id}, {embedding_model}") logger.debug(f"VectorIORouter.register_vector_db: {vector_db_id}, {embedding_model}")
@ -89,6 +91,7 @@ class VectorIORouter(VectorIO):
embedding_model, embedding_model,
embedding_dimension, embedding_dimension,
provider_id, provider_id,
vector_db_name,
provider_vector_db_id, provider_vector_db_id,
) )
@ -123,7 +126,6 @@ class VectorIORouter(VectorIO):
embedding_model: str | None = None, embedding_model: str | None = None,
embedding_dimension: int | None = None, embedding_dimension: int | None = None,
provider_id: str | None = None, provider_id: str | None = None,
provider_vector_db_id: str | None = None,
) -> VectorStoreObject: ) -> VectorStoreObject:
logger.debug(f"VectorIORouter.openai_create_vector_store: name={name}, provider_id={provider_id}") logger.debug(f"VectorIORouter.openai_create_vector_store: name={name}, provider_id={provider_id}")
@ -135,17 +137,17 @@ class VectorIORouter(VectorIO):
embedding_model, embedding_dimension = embedding_model_info embedding_model, embedding_dimension = embedding_model_info
logger.info(f"No embedding model specified, using first available: {embedding_model}") logger.info(f"No embedding model specified, using first available: {embedding_model}")
vector_db_id = name vector_db_id = f"vs_{uuid.uuid4()}"
registered_vector_db = await self.routing_table.register_vector_db( registered_vector_db = await self.routing_table.register_vector_db(
vector_db_id, vector_db_id=vector_db_id,
embedding_model, embedding_model=embedding_model,
embedding_dimension, embedding_dimension=embedding_dimension,
provider_id, provider_id=provider_id,
provider_vector_db_id, provider_vector_db_id=vector_db_id,
vector_db_name=name,
) )
return await self.routing_table.get_provider_impl(registered_vector_db.identifier).openai_create_vector_store( return await self.routing_table.get_provider_impl(registered_vector_db.identifier).openai_create_vector_store(
vector_db_id, name=name,
file_ids=file_ids, file_ids=file_ids,
expires_after=expires_after, expires_after=expires_after,
chunking_strategy=chunking_strategy, chunking_strategy=chunking_strategy,
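
After this change the store identifier is a minted vs_<uuid> and the caller's name only travels as vector_db_name. A small sketch of that split (the helper name is illustrative):

import uuid

def make_vector_store_ids(name: str) -> dict:
    vector_db_id = f"vs_{uuid.uuid4()}"          # opaque, provider-facing identifier
    return {
        "vector_db_id": vector_db_id,
        "provider_vector_db_id": vector_db_id,   # the router passes the same minted id through
        "vector_db_name": name,                  # human-readable, user-facing name
    }

print(make_vector_store_ids("my-documents"))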

View file

@ -80,3 +80,38 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
if existing_model is None: if existing_model is None:
raise ValueError(f"Model {model_id} not found") raise ValueError(f"Model {model_id} not found")
await self.unregister_object(existing_model) await self.unregister_object(existing_model)
async def update_registered_llm_models(
self,
provider_id: str,
models: list[Model],
) -> None:
existing_models = await self.get_all_with_type("model")
# we may have an alias for the model registered by the user (or during initialization
# from run.yaml) that we need to keep track of
model_ids = {}
for model in existing_models:
# we leave embedding models alone because we often don't get metadata
# (embedding dimension, etc.) from the provider
if model.provider_id == provider_id and model.model_type == ModelType.llm:
model_ids[model.provider_resource_id] = model.identifier
logger.debug(f"unregistering model {model.identifier}")
await self.unregister_object(model)
for model in models:
if model.model_type != ModelType.llm:
continue
if model.provider_resource_id in model_ids:
model.identifier = model_ids[model.provider_resource_id]
logger.debug(f"registering model {model.identifier} ({model.provider_resource_id})")
await self.register_object(
ModelWithOwner(
identifier=model.identifier,
provider_resource_id=model.provider_resource_id,
provider_id=provider_id,
metadata=model.metadata,
model_type=model.model_type,
)
)
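
A stripped-down sketch of the alias-preserving refresh above, using plain dicts instead of routing-table objects (purely illustrative data):

# identifier -> provider_resource_id for LLMs already registered (possibly user aliases)
existing = {"my-alias": "llama3:8b", "other": "mistral:7b"}

# provider_resource_ids reported by the provider on refresh
reported = ["llama3:8b", "llama3:70b"]

# keep the user-chosen identifier when the provider model is still present,
# otherwise fall back to the provider's own id
alias_by_resource = {res: ident for ident, res in existing.items()}
refreshed = {alias_by_resource.get(res, res): res for res in reported}

print(refreshed)  # {'my-alias': 'llama3:8b', 'llama3:70b': 'llama3:70b'}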

View file

@ -36,6 +36,7 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
embedding_dimension: int | None = 384, embedding_dimension: int | None = 384,
provider_id: str | None = None, provider_id: str | None = None,
provider_vector_db_id: str | None = None, provider_vector_db_id: str | None = None,
vector_db_name: str | None = None,
) -> VectorDB: ) -> VectorDB:
if provider_vector_db_id is None: if provider_vector_db_id is None:
provider_vector_db_id = vector_db_id provider_vector_db_id = vector_db_id
@ -62,6 +63,7 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
"provider_resource_id": provider_vector_db_id, "provider_resource_id": provider_vector_db_id,
"embedding_model": embedding_model, "embedding_model": embedding_model,
"embedding_dimension": model.metadata["embedding_dimension"], "embedding_dimension": model.metadata["embedding_dimension"],
"vector_db_name": vector_db_name,
} }
vector_db = TypeAdapter(VectorDBWithOwner).validate_python(vector_db_data) vector_db = TypeAdapter(VectorDBWithOwner).validate_python(vector_db_data)
await self.register_object(vector_db) await self.register_object(vector_db)

View file

@ -47,6 +47,7 @@ from llama_stack.distribution.server.routes import (
initialize_route_impls, initialize_route_impls,
) )
from llama_stack.distribution.stack import ( from llama_stack.distribution.stack import (
cast_image_name_to_string,
construct_stack, construct_stack,
replace_env_vars, replace_env_vars,
validate_env_pair, validate_env_pair,
@ -439,14 +440,12 @@ def main(args: argparse.Namespace | None = None):
logger.error(f"Error: {str(e)}") logger.error(f"Error: {str(e)}")
sys.exit(1) sys.exit(1)
config = replace_env_vars(config_contents) config = replace_env_vars(config_contents)
config = StackRunConfig(**config) config = StackRunConfig(**cast_image_name_to_string(config))
# now that the logger is initialized, print the line about which type of config we are using. # now that the logger is initialized, print the line about which type of config we are using.
logger.info(log_line) logger.info(log_line)
logger.info("Run configuration:") _log_run_config(run_config=config)
safe_config = redact_sensitive_fields(config.model_dump(mode="json"))
logger.info(yaml.dump(safe_config, indent=2))
app = FastAPI( app = FastAPI(
lifespan=lifespan, lifespan=lifespan,
@ -454,6 +453,7 @@ def main(args: argparse.Namespace | None = None):
redoc_url="/redoc", redoc_url="/redoc",
openapi_url="/openapi.json", openapi_url="/openapi.json",
) )
if not os.environ.get("LLAMA_STACK_DISABLE_VERSION_CHECK"): if not os.environ.get("LLAMA_STACK_DISABLE_VERSION_CHECK"):
app.add_middleware(ClientVersionMiddleware) app.add_middleware(ClientVersionMiddleware)
@ -492,7 +492,13 @@ def main(args: argparse.Namespace | None = None):
) )
try: try:
impls = asyncio.run(construct_stack(config)) # Create and set the event loop that will be used for both construction and server runtime
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
# Construct the stack in the persistent event loop
impls = loop.run_until_complete(construct_stack(config))
except InvalidProviderError as e: except InvalidProviderError as e:
logger.error(f"Error: {str(e)}") logger.error(f"Error: {str(e)}")
sys.exit(1) sys.exit(1)
@ -590,7 +596,16 @@ def main(args: argparse.Namespace | None = None):
if ssl_config: if ssl_config:
uvicorn_config.update(ssl_config) uvicorn_config.update(ssl_config)
uvicorn.run(**uvicorn_config) # Run uvicorn in the existing event loop to preserve background tasks
loop.run_until_complete(uvicorn.Server(uvicorn.Config(**uvicorn_config)).serve())
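
A minimal standalone sketch of the uvicorn-on-an-existing-loop pattern used above (the host/port and the empty app are placeholders):

import asyncio

import uvicorn
from fastapi import FastAPI

app = FastAPI()

loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
# ... async setup work (e.g. constructing app state) can run on `loop` here ...
server = uvicorn.Server(uvicorn.Config(app, host="127.0.0.1", port=8000))
loop.run_until_complete(server.serve())  # blocks; background tasks keep their loop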
def _log_run_config(run_config: StackRunConfig):
"""Logs the run config with redacted fields and disabled providers removed."""
logger.info("Run configuration:")
safe_config = redact_sensitive_fields(run_config.model_dump(mode="json"))
clean_config = remove_disabled_providers(safe_config)
logger.info(yaml.dump(clean_config, indent=2))
def extract_path_params(route: str) -> list[str]: def extract_path_params(route: str) -> list[str]:
@ -601,5 +616,20 @@ def extract_path_params(route: str) -> list[str]:
return params return params
def remove_disabled_providers(obj):
if isinstance(obj, dict):
if (
obj.get("provider_id") == "__disabled__"
or obj.get("shield_id") == "__disabled__"
or obj.get("provider_model_id") == "__disabled__"
):
return None
return {k: v for k, v in ((k, remove_disabled_providers(v)) for k, v in obj.items()) if v is not None}
elif isinstance(obj, list):
return [item for item in (remove_disabled_providers(i) for i in obj) if item is not None]
else:
return obj
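
A usage sketch for remove_disabled_providers, assuming the function above is in scope; the config shape is illustrative:

config = {
    "providers": {
        "inference": [
            {"provider_id": "ollama", "provider_type": "remote::ollama"},
            {"provider_id": "__disabled__", "provider_type": "remote::vllm"},
        ]
    },
    "shields": [{"shield_id": "__disabled__"}],
}

print(remove_disabled_providers(config))
# {'providers': {'inference': [{'provider_id': 'ollama', 'provider_type': 'remote::ollama'}]},
#  'shields': []}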
if __name__ == "__main__": if __name__ == "__main__":
main() main()

View file

@ -172,7 +172,6 @@ def replace_env_vars(config: Any, path: str = "") -> Any:
# Create a copy with resolved provider_id but original config # Create a copy with resolved provider_id but original config
disabled_provider = v.copy() disabled_provider = v.copy()
disabled_provider["provider_id"] = resolved_provider_id disabled_provider["provider_id"] = resolved_provider_id
result.append(disabled_provider)
continue continue
except EnvVarError: except EnvVarError:
# If we can't resolve the provider_id, continue with normal processing # If we can't resolve the provider_id, continue with normal processing
@ -267,6 +266,13 @@ def _convert_string_to_proper_type(value: str) -> Any:
return value return value
def cast_image_name_to_string(config_dict: dict[str, Any]) -> dict[str, Any]:
"""Ensure that any value for a key 'image_name' in a config_dict is a string"""
if "image_name" in config_dict and config_dict["image_name"] is not None:
config_dict["image_name"] = str(config_dict["image_name"])
return config_dict
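
Why the cast is needed: an unquoted numeric image_name in YAML parses as an int and would then fail the StackRunConfig string field. A self-contained sketch (re-declares the helper so it runs on its own, assuming PyYAML):

import yaml

def cast_image_name_to_string(config_dict):
    # same behavior as above: coerce a non-None image_name to str
    if "image_name" in config_dict and config_dict["image_name"] is not None:
        config_dict["image_name"] = str(config_dict["image_name"])
    return config_dict

config = yaml.safe_load("image_name: 2024\nversion: 2\n")
print(type(config["image_name"]).__name__)                             # int
print(type(cast_image_name_to_string(config)["image_name"]).__name__)  # str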
def validate_env_pair(env_pair: str) -> tuple[str, str]: def validate_env_pair(env_pair: str) -> tuple[str, str]:
"""Validate and split an environment variable key-value pair.""" """Validate and split an environment variable key-value pair."""
try: try:

View file

@ -8,6 +8,7 @@ import io
import json import json
import uuid import uuid
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any
from PIL import Image as PIL_Image from PIL import Image as PIL_Image
@ -184,16 +185,26 @@ class ChatFormat:
content = content[: -len("<|eom_id|>")] content = content[: -len("<|eom_id|>")]
stop_reason = StopReason.end_of_message stop_reason = StopReason.end_of_message
tool_name = None tool_name: str | BuiltinTool | None = None
tool_arguments = {} tool_arguments: dict[str, Any] = {}
custom_tool_info = ToolUtils.maybe_extract_custom_tool_call(content) custom_tool_info = ToolUtils.maybe_extract_custom_tool_call(content)
if custom_tool_info is not None: if custom_tool_info is not None:
tool_name, tool_arguments = custom_tool_info # Type guard: ensure custom_tool_info is a tuple of correct types
if isinstance(custom_tool_info, tuple) and len(custom_tool_info) == 2:
extracted_tool_name, extracted_tool_arguments = custom_tool_info
# Handle both dict and str return types from the function
if isinstance(extracted_tool_arguments, dict):
tool_name, tool_arguments = extracted_tool_name, extracted_tool_arguments
else:
# If it's a string, treat it as a query parameter
tool_name, tool_arguments = extracted_tool_name, {"query": extracted_tool_arguments}
else:
tool_name, tool_arguments = None, {}
# Sometimes when the agent has custom tools alongside builtin tools # Sometimes when the agent has custom tools alongside builtin tools
# the agent responds with builtin tool calls in the format of the custom tools # the agent responds with builtin tool calls in the format of the custom tools
# This code tries to handle that case # This code tries to handle that case
if tool_name in BuiltinTool.__members__: if tool_name is not None and tool_name in BuiltinTool.__members__:
tool_name = BuiltinTool[tool_name] tool_name = BuiltinTool[tool_name]
if isinstance(tool_arguments, dict): if isinstance(tool_arguments, dict):
tool_arguments = { tool_arguments = {

View file

@ -6,7 +6,7 @@
from typing import Any from typing import Any
from llama_stack.distribution.datatypes import Api from llama_stack.distribution.datatypes import AccessRule, Api
from .config import LocalfsFilesImplConfig from .config import LocalfsFilesImplConfig
from .files import LocalfsFilesImpl from .files import LocalfsFilesImpl
@ -14,7 +14,7 @@ from .files import LocalfsFilesImpl
__all__ = ["LocalfsFilesImpl", "LocalfsFilesImplConfig"] __all__ = ["LocalfsFilesImpl", "LocalfsFilesImplConfig"]
async def get_provider_impl(config: LocalfsFilesImplConfig, deps: dict[Api, Any]): async def get_provider_impl(config: LocalfsFilesImplConfig, deps: dict[Api, Any], policy: list[AccessRule]):
impl = LocalfsFilesImpl(config) impl = LocalfsFilesImpl(config, policy)
await impl.initialize() await impl.initialize()
return impl return impl

View file

@ -19,16 +19,19 @@ from llama_stack.apis.files import (
OpenAIFileObject, OpenAIFileObject,
OpenAIFilePurpose, OpenAIFilePurpose,
) )
from llama_stack.distribution.datatypes import AccessRule
from llama_stack.providers.utils.sqlstore.api import ColumnDefinition, ColumnType from llama_stack.providers.utils.sqlstore.api import ColumnDefinition, ColumnType
from llama_stack.providers.utils.sqlstore.sqlstore import SqlStore, sqlstore_impl from llama_stack.providers.utils.sqlstore.authorized_sqlstore import AuthorizedSqlStore
from llama_stack.providers.utils.sqlstore.sqlstore import sqlstore_impl
from .config import LocalfsFilesImplConfig from .config import LocalfsFilesImplConfig
class LocalfsFilesImpl(Files): class LocalfsFilesImpl(Files):
def __init__(self, config: LocalfsFilesImplConfig) -> None: def __init__(self, config: LocalfsFilesImplConfig, policy: list[AccessRule]) -> None:
self.config = config self.config = config
self.sql_store: SqlStore | None = None self.policy = policy
self.sql_store: AuthorizedSqlStore | None = None
async def initialize(self) -> None: async def initialize(self) -> None:
"""Initialize the files provider by setting up storage directory and metadata database.""" """Initialize the files provider by setting up storage directory and metadata database."""
@ -37,7 +40,7 @@ class LocalfsFilesImpl(Files):
storage_path.mkdir(parents=True, exist_ok=True) storage_path.mkdir(parents=True, exist_ok=True)
# Initialize SQL store for metadata # Initialize SQL store for metadata
self.sql_store = sqlstore_impl(self.config.metadata_store) self.sql_store = AuthorizedSqlStore(sqlstore_impl(self.config.metadata_store))
await self.sql_store.create_table( await self.sql_store.create_table(
"openai_files", "openai_files",
{ {
@ -51,6 +54,9 @@ class LocalfsFilesImpl(Files):
}, },
) )
async def shutdown(self) -> None:
pass
def _generate_file_id(self) -> str: def _generate_file_id(self) -> str:
"""Generate a unique file ID for OpenAI API.""" """Generate a unique file ID for OpenAI API."""
return f"file-{uuid.uuid4().hex}" return f"file-{uuid.uuid4().hex}"
@ -123,6 +129,7 @@ class LocalfsFilesImpl(Files):
paginated_result = await self.sql_store.fetch_all( paginated_result = await self.sql_store.fetch_all(
table="openai_files", table="openai_files",
policy=self.policy,
where=where_conditions if where_conditions else None, where=where_conditions if where_conditions else None,
order_by=[("created_at", order.value)], order_by=[("created_at", order.value)],
cursor=("id", after) if after else None, cursor=("id", after) if after else None,
@ -153,7 +160,7 @@ class LocalfsFilesImpl(Files):
if not self.sql_store: if not self.sql_store:
raise RuntimeError("Files provider not initialized") raise RuntimeError("Files provider not initialized")
row = await self.sql_store.fetch_one("openai_files", where={"id": file_id}) row = await self.sql_store.fetch_one("openai_files", policy=self.policy, where={"id": file_id})
if not row: if not row:
raise ValueError(f"File with id {file_id} not found") raise ValueError(f"File with id {file_id} not found")
@ -171,7 +178,7 @@ class LocalfsFilesImpl(Files):
if not self.sql_store: if not self.sql_store:
raise RuntimeError("Files provider not initialized") raise RuntimeError("Files provider not initialized")
row = await self.sql_store.fetch_one("openai_files", where={"id": file_id}) row = await self.sql_store.fetch_one("openai_files", policy=self.policy, where={"id": file_id})
if not row: if not row:
raise ValueError(f"File with id {file_id} not found") raise ValueError(f"File with id {file_id} not found")
@ -194,7 +201,7 @@ class LocalfsFilesImpl(Files):
raise RuntimeError("Files provider not initialized") raise RuntimeError("Files provider not initialized")
# Get file metadata # Get file metadata
row = await self.sql_store.fetch_one("openai_files", where={"id": file_id}) row = await self.sql_store.fetch_one("openai_files", policy=self.policy, where={"id": file_id})
if not row: if not row:
raise ValueError(f"File with id {file_id} not found") raise ValueError(f"File with id {file_id} not found")

View file

@ -98,7 +98,7 @@ class ProcessingMessageWrapper(BaseModel):
def mp_rank_0() -> bool: def mp_rank_0() -> bool:
return get_model_parallel_rank() == 0 return bool(get_model_parallel_rank() == 0)
def encode_msg(msg: ProcessingMessage) -> bytes: def encode_msg(msg: ProcessingMessage) -> bytes:
@ -125,7 +125,7 @@ def retrieve_requests(reply_socket_url: str):
reply_socket.send_multipart([client_id, encode_msg(obj)]) reply_socket.send_multipart([client_id, encode_msg(obj)])
while True: while True:
tasks = [None] tasks: list[ProcessingMessage | None] = [None]
if mp_rank_0(): if mp_rank_0():
client_id, maybe_task_json = maybe_get_work(reply_socket) client_id, maybe_task_json = maybe_get_work(reply_socket)
if maybe_task_json is not None: if maybe_task_json is not None:
@ -152,7 +152,7 @@ def retrieve_requests(reply_socket_url: str):
break break
for obj in out: for obj in out:
updates = [None] updates: list[ProcessingMessage | None] = [None]
if mp_rank_0(): if mp_rank_0():
_, update_json = maybe_get_work(reply_socket) _, update_json = maybe_get_work(reply_socket)
update = maybe_parse_message(update_json) update = maybe_parse_message(update_json)

View file

@ -1,17 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any
from .config import VLLMConfig
async def get_provider_impl(config: VLLMConfig, _deps: dict[str, Any]):
from .vllm import VLLMInferenceImpl
impl = VLLMInferenceImpl(config)
await impl.initialize()
return impl

View file

@ -1,53 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any
from pydantic import BaseModel, Field
from llama_stack.schema_utils import json_schema_type
@json_schema_type
class VLLMConfig(BaseModel):
"""Configuration for the vLLM inference provider.
Note that the model name is no longer part of this static configuration.
You can bind an instance of this provider to a specific model with the
``models.register()`` API call."""
tensor_parallel_size: int = Field(
default=1,
description="Number of tensor parallel replicas (number of GPUs to use).",
)
max_tokens: int = Field(
default=4096,
description="Maximum number of tokens to generate.",
)
max_model_len: int = Field(default=4096, description="Maximum context length to use during serving.")
max_num_seqs: int = Field(default=4, description="Maximum parallel batch size for generation.")
enforce_eager: bool = Field(
default=False,
description="Whether to use eager mode for inference (otherwise cuda graphs are used).",
)
gpu_memory_utilization: float = Field(
default=0.3,
description=(
"How much GPU memory will be allocated when this provider has finished "
"loading, including memory that was already allocated before loading."
),
)
@classmethod
def sample_run_config(cls, **kwargs: Any) -> dict[str, Any]:
return {
"tensor_parallel_size": "${env.TENSOR_PARALLEL_SIZE:=1}",
"max_tokens": "${env.MAX_TOKENS:=4096}",
"max_model_len": "${env.MAX_MODEL_LEN:=4096}",
"max_num_seqs": "${env.MAX_NUM_SEQS:=4}",
"enforce_eager": "${env.ENFORCE_EAGER:=False}",
"gpu_memory_utilization": "${env.GPU_MEMORY_UTILIZATION:=0.3}",
}

View file

@ -1,170 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import vllm
from llama_stack.apis.inference import (
ChatCompletionRequest,
GrammarResponseFormat,
JsonSchemaResponseFormat,
Message,
ToolChoice,
ToolDefinition,
UserMessage,
)
from llama_stack.models.llama.datatypes import BuiltinTool
from llama_stack.providers.utils.inference.openai_compat import (
convert_message_to_openai_dict,
get_sampling_options,
)
###############################################################################
# This file contains OpenAI compatibility code that is currently only used
# by the inline vLLM connector. Some or all of this code may be moved to a
# central location at a later date.
def _merge_context_into_content(message: Message) -> Message: # type: ignore
"""
Merge the ``context`` field of a Llama Stack ``Message`` object into
the content field for compatibility with OpenAI-style APIs.
Generates a content string that emulates the current behavior
of ``llama_models.llama3.api.chat_format.encode_message()``.
:param message: Message that may include ``context`` field
:returns: A version of ``message`` with any context merged into the
``content`` field.
"""
if not isinstance(message, UserMessage): # Separate type check for linter
return message
if message.context is None:
return message
return UserMessage(
role=message.role,
# Emulate llama_models.llama3.api.chat_format.encode_message()
content=message.content + "\n\n" + message.context,
context=None,
)
def _llama_stack_tools_to_openai_tools(
tools: list[ToolDefinition] | None = None,
) -> list[vllm.entrypoints.openai.protocol.ChatCompletionToolsParam]:
"""
Convert the list of available tools from Llama Stack's format to vLLM's
version of OpenAI's format.
"""
if tools is None:
return []
result = []
for t in tools:
if isinstance(t.tool_name, BuiltinTool):
raise NotImplementedError("Built-in tools not yet implemented")
if t.parameters is None:
parameters = None
else: # if t.parameters is not None
# Convert the "required" flags to a list of required params
required_params = [k for k, v in t.parameters.items() if v.required]
parameters = {
"type": "object", # Mystery value that shows up in OpenAI docs
"properties": {
k: {"type": v.param_type, "description": v.description} for k, v in t.parameters.items()
},
"required": required_params,
}
function_def = vllm.entrypoints.openai.protocol.FunctionDefinition(
name=t.tool_name, description=t.description, parameters=parameters
)
# Every tool definition is double-boxed in a ChatCompletionToolsParam
result.append(vllm.entrypoints.openai.protocol.ChatCompletionToolsParam(function=function_def))
return result
async def llama_stack_chat_completion_to_openai_chat_completion_dict(
request: ChatCompletionRequest,
) -> dict:
"""
Convert a chat completion request in Llama Stack format into an
equivalent set of arguments to pass to an OpenAI-compatible
chat completions API.
:param request: Bundled request parameters in Llama Stack format.
:returns: Dictionary of key-value pairs to use as an initializer
for a dataclass or to be converted directly to JSON and sent
over the wire.
"""
converted_messages = [
# This mystery async call makes the parent function also be async
await convert_message_to_openai_dict(_merge_context_into_content(m), download=True)
for m in request.messages
]
converted_tools = _llama_stack_tools_to_openai_tools(request.tools)
# Llama will try to use built-in tools with no tool catalog, so don't enable
# tool choice unless at least one tool is enabled.
converted_tool_choice = "none"
if (
request.tool_config is not None
and request.tool_config.tool_choice == ToolChoice.auto
and request.tools is not None
and len(request.tools) > 0
):
converted_tool_choice = "auto"
# TODO: Figure out what to do with the tool_prompt_format argument.
# Other connectors appear to drop it quietly.
# Use Llama Stack shared code to translate sampling parameters.
sampling_options = get_sampling_options(request.sampling_params)
# get_sampling_options() translates repetition penalties to an option that
# OpenAI's APIs don't know about.
# vLLM's OpenAI-compatible API also handles repetition penalties wrong.
# For now, translate repetition penalties into a format that vLLM's broken
# API will handle correctly. Two wrongs make a right...
if "repeat_penalty" in sampling_options:
del sampling_options["repeat_penalty"]
if request.sampling_params.repetition_penalty is not None and request.sampling_params.repetition_penalty != 1.0:
sampling_options["repetition_penalty"] = request.sampling_params.repetition_penalty
# Convert a single response format into four different parameters, per
# the OpenAI spec
guided_decoding_options = dict()
if request.response_format is None:
# Use defaults
pass
elif isinstance(request.response_format, JsonSchemaResponseFormat):
guided_decoding_options["guided_json"] = request.response_format.json_schema
elif isinstance(request.response_format, GrammarResponseFormat):
guided_decoding_options["guided_grammar"] = request.response_format.bnf
else:
raise TypeError(f"ResponseFormat object is of unexpected subtype '{type(request.response_format)}'")
logprob_options = dict()
if request.logprobs is not None:
logprob_options["logprobs"] = request.logprobs.top_k
# Marshall together all the arguments for a ChatCompletionRequest
request_options = {
"model": request.model,
"messages": converted_messages,
"tools": converted_tools,
"tool_choice": converted_tool_choice,
"stream": request.stream,
**sampling_options,
**guided_decoding_options,
**logprob_options,
}
return request_options

View file

@ -1,811 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import json
import re
import uuid
from collections.abc import AsyncGenerator, AsyncIterator
# These vLLM modules contain names that overlap with Llama Stack names, so we import
# fully-qualified names
import vllm.entrypoints.openai.protocol
import vllm.sampling_params
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
from vllm.entrypoints.openai.serving_models import BaseModelPath, OpenAIServingModels
from llama_stack.apis.common.content_types import (
InterleavedContent,
InterleavedContentItem,
TextDelta,
ToolCallDelta,
)
from llama_stack.apis.inference import (
ChatCompletionRequest,
ChatCompletionResponse,
ChatCompletionResponseEvent,
ChatCompletionResponseEventType,
ChatCompletionResponseStreamChunk,
CompletionMessage,
CompletionResponse,
CompletionResponseStreamChunk,
EmbeddingsResponse,
EmbeddingTaskType,
GrammarResponseFormat,
Inference,
JsonSchemaResponseFormat,
LogProbConfig,
Message,
OpenAIEmbeddingsResponse,
ResponseFormat,
SamplingParams,
TextTruncation,
TokenLogProbs,
ToolChoice,
ToolConfig,
TopKSamplingStrategy,
TopPSamplingStrategy,
)
from llama_stack.apis.models import Model
from llama_stack.log import get_logger
from llama_stack.models.llama import sku_list
from llama_stack.models.llama.datatypes import (
StopReason,
ToolCall,
ToolDefinition,
ToolPromptFormat,
)
from llama_stack.models.llama.llama3.chat_format import ChatFormat
from llama_stack.models.llama.llama3.tokenizer import Tokenizer
from llama_stack.providers.remote.inference.vllm.vllm import build_hf_repo_model_entries
from llama_stack.providers.utils.inference.model_registry import (
ModelRegistryHelper,
ModelsProtocolPrivate,
)
from llama_stack.providers.utils.inference.openai_compat import (
OpenAIChatCompletionToLlamaStackMixin,
OpenAICompatCompletionChoice,
OpenAICompatCompletionResponse,
OpenAICompletionToLlamaStackMixin,
get_stop_reason,
process_chat_completion_stream_response,
)
from llama_stack.providers.utils.inference.prompt_adapter import (
chat_completion_request_to_prompt,
)
from .config import VLLMConfig
from .openai_utils import llama_stack_chat_completion_to_openai_chat_completion_dict
# Map from Hugging Face model architecture name to appropriate tool parser.
# See vllm.entrypoints.openai.tool_parsers.ToolParserManager.tool_parsers for the full list of
# available parsers.
# TODO: Expand this list
CONFIG_TYPE_TO_TOOL_PARSER = {
"GraniteConfig": "granite",
"MllamaConfig": "llama3_json",
"LlamaConfig": "llama3_json",
}
DEFAULT_TOOL_PARSER = "pythonic"
logger = get_logger(__name__, category="inference")
def _random_uuid_str() -> str:
return str(uuid.uuid4().hex)
def _response_format_to_guided_decoding_params(
response_format: ResponseFormat | None, # type: ignore
) -> vllm.sampling_params.GuidedDecodingParams:
"""
Translate constrained decoding parameters from Llama Stack's format to vLLM's format.
:param response_format: Llama Stack version of constrained decoding info. Can be ``None``,
indicating no constraints.
:returns: The equivalent dataclass object for the low-level inference layer of vLLM.
"""
if response_format is None:
# As of vLLM 0.6.3, the default constructor for GuidedDecodingParams() returns an invalid
# value that crashes the executor on some code paths. Use ``None`` instead.
return None
# Llama Stack currently implements fewer types of constrained decoding than vLLM does.
# Translate the types that exist and detect if Llama Stack adds new ones.
if isinstance(response_format, JsonSchemaResponseFormat):
return vllm.sampling_params.GuidedDecodingParams(json=response_format.json_schema)
elif isinstance(response_format, GrammarResponseFormat):
# BNF grammar.
# Llama Stack uses the parse tree of the grammar, while vLLM uses the string
# representation of the grammar.
raise TypeError(
"Constrained decoding with BNF grammars is not currently implemented, because the "
"reference implementation does not implement it."
)
else:
raise TypeError(f"ResponseFormat object is of unexpected subtype '{type(response_format)}'")
def _convert_sampling_params(
sampling_params: SamplingParams | None,
response_format: ResponseFormat | None, # type: ignore
log_prob_config: LogProbConfig | None,
) -> vllm.SamplingParams:
"""Convert sampling and constrained decoding configuration from Llama Stack's format to vLLM's
format."""
# In the absence of provided config values, use Llama Stack defaults as encoded in the Llama
# Stack dataclasses. These defaults are different from vLLM's defaults.
if sampling_params is None:
sampling_params = SamplingParams()
if log_prob_config is None:
log_prob_config = LogProbConfig()
if isinstance(sampling_params.strategy, TopKSamplingStrategy):
if sampling_params.strategy.top_k == 0:
# vLLM treats "k" differently for top-k sampling
vllm_top_k = -1
else:
vllm_top_k = sampling_params.strategy.top_k
else:
vllm_top_k = -1
if isinstance(sampling_params.strategy, TopPSamplingStrategy):
vllm_top_p = sampling_params.strategy.top_p
# Llama Stack only allows temperature with top-P.
vllm_temperature = sampling_params.strategy.temperature
else:
vllm_top_p = 1.0
vllm_temperature = 0.0
# vLLM allows top-p and top-k at the same time.
vllm_sampling_params = vllm.SamplingParams.from_optional(
max_tokens=(None if sampling_params.max_tokens == 0 else sampling_params.max_tokens),
temperature=vllm_temperature,
top_p=vllm_top_p,
top_k=vllm_top_k,
repetition_penalty=sampling_params.repetition_penalty,
guided_decoding=_response_format_to_guided_decoding_params(response_format),
logprobs=log_prob_config.top_k,
)
return vllm_sampling_params
class VLLMInferenceImpl(
Inference,
OpenAIChatCompletionToLlamaStackMixin,
OpenAICompletionToLlamaStackMixin,
ModelsProtocolPrivate,
):
"""
vLLM-based inference model adapter for Llama Stack with support for multiple models.
Requires the configuration parameters documented in the :class:`VLLMConfig` class.
"""
config: VLLMConfig
register_helper: ModelRegistryHelper
model_ids: set[str]
resolved_model_id: str | None
engine: AsyncLLMEngine | None
chat: OpenAIServingChat | None
is_meta_llama_model: bool
def __init__(self, config: VLLMConfig):
self.config = config
logger.info(f"Config is: {self.config}")
self.register_helper = ModelRegistryHelper(build_hf_repo_model_entries())
self.formatter = ChatFormat(Tokenizer.get_instance())
# The following are initialized when paths are bound to this provider
self.resolved_model_id = None
self.model_ids = set()
self.engine = None
self.chat = None
self.is_meta_llama_model = False
###########################################################################
# METHODS INHERITED FROM IMPLICIT BASE CLASS.
# TODO: Make this class inherit from the new base class ProviderBase once that class exists.
async def initialize(self) -> None:
"""
Callback that is invoked through many levels of indirection during provider class
instantiation, sometime after __init__() is called and before any model registration
methods or methods connected to a REST API are called.
It's not clear what assumptions the class can make about the platform's initialization
state here that can't be made during __init__(), and vLLM can't be started until we know
what model it's supposed to be serving, so nothing happens here currently.
"""
pass
async def shutdown(self) -> None:
logger.info(f"Shutting down inline vLLM inference provider {self}.")
if self.engine is not None:
self.engine.shutdown_background_loop()
self.engine = None
self.chat = None
self.model_ids = set()
self.resolved_model_id = None
###########################################################################
# METHODS INHERITED FROM ModelsProtocolPrivate INTERFACE
# Note that the return type of the superclass method is WRONG
async def register_model(self, model: Model) -> Model:
"""
Callback that is called when the server associates an inference endpoint with an
inference provider.
:param model: Object that encapsulates parameters necessary for identifying a specific
LLM.
:returns: The input ``Model`` object. It may or may not be permissible to change fields
before returning this object.
"""
logger.debug(f"In register_model({model})")
# First attempt to interpret the model coordinates as a Llama model name
resolved_llama_model = sku_list.resolve_model(model.provider_model_id)
if resolved_llama_model is not None:
# Load from Hugging Face repo into default local cache dir
model_id_for_vllm = resolved_llama_model.huggingface_repo
# Detect a genuine Meta Llama model to trigger Meta-specific preprocessing.
# Don't set self.is_meta_llama_model until we actually load the model.
is_meta_llama_model = True
else: # if resolved_llama_model is None
# Not a Llama model name. Pass the model id through to vLLM's loader
model_id_for_vllm = model.provider_model_id
is_meta_llama_model = False
if self.resolved_model_id is not None:
if model_id_for_vllm != self.resolved_model_id:
raise ValueError(
f"Attempted to serve two LLMs (ids '{self.resolved_model_id}') and "
f"'{model_id_for_vllm}') from one copy of provider '{self}'. Use multiple "
f"copies of the provider instead."
)
else:
# Model already loaded
logger.info(
f"Requested id {model} resolves to {model_id_for_vllm}, which is already loaded. Continuing."
)
self.model_ids.add(model.model_id)
return model
logger.info(f"Requested id {model} resolves to {model_id_for_vllm}. Loading {model_id_for_vllm}.")
if is_meta_llama_model:
logger.info(f"Model {model_id_for_vllm} is a Meta Llama model.")
self.is_meta_llama_model = is_meta_llama_model
# If we get here, this is the first time registering a model.
# Preload so that the first inference request won't time out.
engine_args = AsyncEngineArgs(
model=model_id_for_vllm,
tokenizer=model_id_for_vllm,
tensor_parallel_size=self.config.tensor_parallel_size,
enforce_eager=self.config.enforce_eager,
gpu_memory_utilization=self.config.gpu_memory_utilization,
max_num_seqs=self.config.max_num_seqs,
max_model_len=self.config.max_model_len,
)
self.engine = AsyncLLMEngine.from_engine_args(engine_args)
# vLLM currently requires the user to specify the tool parser manually. To choose a tool
# parser, we need to determine what model architecture is being used. For now, we infer
# that information from what config class the model uses.
low_level_model_config = self.engine.engine.get_model_config()
hf_config = low_level_model_config.hf_config
hf_config_class_name = hf_config.__class__.__name__
if hf_config_class_name in CONFIG_TYPE_TO_TOOL_PARSER:
tool_parser = CONFIG_TYPE_TO_TOOL_PARSER[hf_config_class_name]
else:
# No info -- choose a default so we can at least attempt tool
# use.
tool_parser = DEFAULT_TOOL_PARSER
logger.debug(f"{hf_config_class_name=}")
logger.debug(f"{tool_parser=}")
# Wrap the lower-level engine in an OpenAI-compatible chat API
model_config = await self.engine.get_model_config()
self.chat = OpenAIServingChat(
engine_client=self.engine,
model_config=model_config,
models=OpenAIServingModels(
engine_client=self.engine,
model_config=model_config,
base_model_paths=[
# The layer below us will only see resolved model IDs
BaseModelPath(model_id_for_vllm, model_id_for_vllm)
],
),
response_role="assistant",
request_logger=None, # Use default logging
chat_template=None, # Use default template from model checkpoint
enable_auto_tools=True,
tool_parser=tool_parser,
chat_template_content_format="auto",
)
self.resolved_model_id = model_id_for_vllm
self.model_ids.add(model.model_id)
logger.info(f"Finished preloading model: {model_id_for_vllm}")
return model
async def unregister_model(self, model_id: str) -> None:
"""
Callback that is called when the server removes an inference endpoint from an inference
provider.
:param model_id: The same external ID that the higher layers of the stack previously passed
to :func:`register_model()`
"""
if model_id not in self.model_ids:
raise ValueError(
f"Attempted to unregister model ID '{model_id}', but that ID is not registered to this provider."
)
self.model_ids.remove(model_id)
if len(self.model_ids) == 0:
# Last model was just unregistered. Shut down the connection to vLLM and free up
# resources.
# Note that this operation may cause in-flight chat completion requests on the
# now-unregistered model to return errors.
self.resolved_model_id = None
self.chat = None
self.engine.shutdown_background_loop()
self.engine = None
###########################################################################
# METHODS INHERITED FROM Inference INTERFACE
async def completion(
self,
model_id: str,
content: InterleavedContent,
sampling_params: SamplingParams | None = None,
response_format: ResponseFormat | None = None,
stream: bool | None = False,
logprobs: LogProbConfig | None = None,
) -> CompletionResponse | AsyncIterator[CompletionResponseStreamChunk]:
if model_id not in self.model_ids:
raise ValueError(
f"This adapter is not registered to model id '{model_id}'. Registered IDs are: {self.model_ids}"
)
if not isinstance(content, str):
raise NotImplementedError("Multimodal input not currently supported")
if sampling_params is None:
sampling_params = SamplingParams()
converted_sampling_params = _convert_sampling_params(sampling_params, response_format, logprobs)
logger.debug(f"{converted_sampling_params=}")
if stream:
return self._streaming_completion(content, converted_sampling_params)
else:
streaming_result = None
async for streaming_result in self._streaming_completion(content, converted_sampling_params):
pass
return CompletionResponse(
content=streaming_result.delta,
stop_reason=streaming_result.stop_reason,
logprobs=streaming_result.logprobs,
)
async def embeddings(
self,
model_id: str,
contents: list[str] | list[InterleavedContentItem],
text_truncation: TextTruncation | None = TextTruncation.none,
output_dimension: int | None = None,
task_type: EmbeddingTaskType | None = None,
) -> EmbeddingsResponse:
raise NotImplementedError()
async def openai_embeddings(
self,
model: str,
input: str | list[str],
encoding_format: str | None = "float",
dimensions: int | None = None,
user: str | None = None,
) -> OpenAIEmbeddingsResponse:
raise NotImplementedError()
async def chat_completion(
self,
model_id: str,
messages: list[Message], # type: ignore
sampling_params: SamplingParams | None = None,
response_format: ResponseFormat | None = None, # type: ignore
tools: list[ToolDefinition] | None = None,
tool_choice: ToolChoice | None = ToolChoice.auto,
tool_prompt_format: ToolPromptFormat | None = None,
stream: bool | None = False,
logprobs: LogProbConfig | None = None,
tool_config: ToolConfig | None = None,
) -> ChatCompletionResponse | ChatCompletionResponseStreamChunk:
sampling_params = sampling_params or SamplingParams()
if model_id not in self.model_ids:
raise ValueError(
f"This adapter is not registered to model id '{model_id}'. Registered IDs are: {self.model_ids}"
)
# Convert to Llama Stack internal format for consistency
request = ChatCompletionRequest(
model=self.resolved_model_id,
messages=messages,
sampling_params=sampling_params,
response_format=response_format,
tools=tools,
tool_choice=tool_choice,
tool_prompt_format=tool_prompt_format,
stream=stream,
logprobs=logprobs,
)
if self.is_meta_llama_model:
# Bypass vLLM chat templating layer for Meta Llama models, because the
# templating layer in Llama Stack currently produces better results.
logger.debug(
f"Routing {self.resolved_model_id} chat completion through "
f"Llama Stack's templating layer instead of vLLM's."
)
return await self._chat_completion_for_meta_llama(request)
logger.debug(f"{self.resolved_model_id} is not a Meta Llama model")
# Arguments to the vLLM call must be packaged as a ChatCompletionRequest dataclass.
# Note that this dataclass has the same name as a similar dataclass in Llama Stack.
request_options = await llama_stack_chat_completion_to_openai_chat_completion_dict(request)
chat_completion_request = vllm.entrypoints.openai.protocol.ChatCompletionRequest(**request_options)
logger.debug(f"Converted request: {chat_completion_request}")
vllm_result = await self.chat.create_chat_completion(chat_completion_request)
logger.debug(f"Result from vLLM: {vllm_result}")
if isinstance(vllm_result, vllm.entrypoints.openai.protocol.ErrorResponse):
raise ValueError(f"Error from vLLM layer: {vllm_result}")
# Return type depends on "stream" argument
if stream:
if not isinstance(vllm_result, AsyncGenerator):
raise TypeError(f"Unexpected result type {type(vllm_result)} for streaming inference call")
# vLLM client returns a stream of strings, which need to be parsed.
# Stream comes in the form of an async generator.
return self._convert_streaming_results(vllm_result)
else:
if not isinstance(vllm_result, vllm.entrypoints.openai.protocol.ChatCompletionResponse):
raise TypeError(f"Unexpected result type {type(vllm_result)} for non-streaming inference call")
return self._convert_non_streaming_results(vllm_result)
###########################################################################
# INTERNAL METHODS
async def _streaming_completion(
self, content: str, sampling_params: vllm.SamplingParams
) -> AsyncIterator[CompletionResponseStreamChunk]:
"""Internal implementation of :func:`completion()` API for the streaming case. Assumes
that arguments have been validated upstream.
:param content: Must be a string
:param sampling_params: Parameters from public API's ``response_format``
and ``sampling_params`` arguments, converted to VLLM format
"""
# We run against the vLLM generate() call directly instead of using the OpenAI-compatible
# layer, because doing so simplifies the code here.
# The vLLM engine requires a unique identifier for each call to generate()
request_id = _random_uuid_str()
# The vLLM generate() API is streaming-only and returns an async generator.
# The generator returns objects of type vllm.RequestOutput.
results_generator = self.engine.generate(content, sampling_params, request_id)
# Need to know the model's EOS token ID for the conversion code below.
# AsyncLLMEngine is a wrapper around LLMEngine, and the tokenizer is only available if
# we drill down to the LLMEngine inside the AsyncLLMEngine.
# Similarly, the tokenizer in an LLMEngine is a wrapper around a BaseTokenizerGroup,
# and we need to drill down to the Hugging Face tokenizer inside the BaseTokenizerGroup.
llm_engine = self.engine.engine
tokenizer_group = llm_engine.tokenizer
eos_token_id = tokenizer_group.tokenizer.eos_token_id
request_output: vllm.RequestOutput = None
async for request_output in results_generator:
# Check for weird inference failures
if request_output.outputs is None or len(request_output.outputs) == 0:
# This case also should never happen
raise ValueError("Inference produced empty result")
# If we get here, then request_output contains the final output of the generate() call.
# The result may include multiple alternate outputs, but Llama Stack APIs only allow
# us to return one.
output: vllm.CompletionOutput = request_output.outputs[0]
completion_string = output.text
# Convert logprobs from vLLM's format to Llama Stack's format
logprobs = [
TokenLogProbs(logprobs_by_token={v.decoded_token: v.logprob for _, v in logprob_dict.items()})
for logprob_dict in output.logprobs
]
# The final output chunk should be labeled with the reason that the overall generate()
# call completed.
logger.debug(f"{output.stop_reason=}; {type(output.stop_reason)=}")
if output.stop_reason is None:
stop_reason = None # Still going
elif output.stop_reason == "stop":
stop_reason = StopReason.end_of_turn
elif output.stop_reason == "length":
stop_reason = StopReason.out_of_tokens
elif isinstance(output.stop_reason, int):
# If the model config specifies multiple end-of-sequence tokens, then vLLM
# will return the token ID of the EOS token in the stop_reason field.
stop_reason = StopReason.end_of_turn
else:
raise ValueError(f"Unrecognized stop reason '{output.stop_reason}'")
# vLLM's protocol outputs the stop token, then sets end of message on the next step for
# some reason.
if request_output.outputs[-1].token_ids[-1] == eos_token_id:
stop_reason = StopReason.end_of_message
yield CompletionResponseStreamChunk(delta=completion_string, stop_reason=stop_reason, logprobs=logprobs)
# Llama Stack requires that the last chunk have a stop reason, but vLLM doesn't always
# provide one if it runs out of tokens.
if stop_reason is None:
yield CompletionResponseStreamChunk(
delta=completion_string,
stop_reason=StopReason.out_of_tokens,
logprobs=logprobs,
)
def _convert_non_streaming_results(
self, vllm_result: vllm.entrypoints.openai.protocol.ChatCompletionResponse
) -> ChatCompletionResponse:
"""
Subroutine to convert the non-streaming output of vLLM's OpenAI-compatible API into an
equivalent Llama Stack object.
The result from vLLM's non-streaming API is a dataclass with the same name as the Llama
Stack ChatCompletionResponse dataclass, but with more and different field names. We ignore
the fields that aren't currently present in the Llama Stack dataclass.
"""
# There may be multiple responses, but we can only pass through the first one.
if len(vllm_result.choices) == 0:
raise ValueError("Don't know how to convert response object without any responses")
vllm_message = vllm_result.choices[0].message
vllm_finish_reason = vllm_result.choices[0].finish_reason
converted_message = CompletionMessage(
role=vllm_message.role,
# Llama Stack API won't accept None for content field.
content=("" if vllm_message.content is None else vllm_message.content),
stop_reason=get_stop_reason(vllm_finish_reason),
tool_calls=[
ToolCall(
call_id=t.id,
tool_name=t.function.name,
# vLLM function args come back as a string. Llama Stack expects JSON.
arguments=json.loads(t.function.arguments),
arguments_json=t.function.arguments,
)
for t in vllm_message.tool_calls
],
)
# TODO: Convert logprobs
logger.debug(f"Converted message: {converted_message}")
return ChatCompletionResponse(
completion_message=converted_message,
)
async def _chat_completion_for_meta_llama(
self, request: ChatCompletionRequest
) -> ChatCompletionResponse | AsyncIterator[ChatCompletionResponseStreamChunk]:
"""
Subroutine that routes chat completions for Meta Llama models through Llama Stack's
chat template instead of using vLLM's version of that template. The Llama Stack version
of the chat template currently produces more reliable outputs.
Once vLLM's support for Meta Llama models has matured more, we should consider routing
Meta Llama requests through the vLLM chat completions API instead of using this method.
"""
formatter = ChatFormat(Tokenizer.get_instance())
# Note that this function call modifies `request` in place.
prompt = await chat_completion_request_to_prompt(request, self.resolved_model_id)
model_id = list(self.model_ids)[0] # Any model ID will do here
completion_response_or_iterator = await self.completion(
model_id=model_id,
content=prompt,
sampling_params=request.sampling_params,
response_format=request.response_format,
stream=request.stream,
logprobs=request.logprobs,
)
if request.stream:
if not isinstance(completion_response_or_iterator, AsyncIterator):
raise TypeError(
f"Received unexpected result type {type(completion_response_or_iterator)}for streaming request."
)
return self._chat_completion_for_meta_llama_streaming(completion_response_or_iterator, request)
# Non-streaming case (request.stream is False)
if not isinstance(completion_response_or_iterator, CompletionResponse):
raise TypeError(
f"Received unexpected result type {type(completion_response_or_iterator)}for non-streaming request."
)
completion_response: CompletionResponse = completion_response_or_iterator
raw_message = formatter.decode_assistant_message_from_content(
completion_response.content, completion_response.stop_reason
)
return ChatCompletionResponse(
completion_message=CompletionMessage(
content=raw_message.content,
stop_reason=raw_message.stop_reason,
tool_calls=raw_message.tool_calls,
),
logprobs=completion_response.logprobs,
)
async def _chat_completion_for_meta_llama_streaming(
self, results_iterator: AsyncIterator, request: ChatCompletionRequest
) -> AsyncIterator:
"""
Code from :func:`_chat_completion_for_meta_llama()` that needs to be a separate
method to keep asyncio happy.
"""
# Convert to OpenAI format, then use shared code to convert to Llama Stack format.
async def _generate_and_convert_to_openai_compat():
chunk: CompletionResponseStreamChunk # Make Pylance happy
last_text_len = 0
async for chunk in results_iterator:
if chunk.stop_reason == StopReason.end_of_turn:
finish_reason = "stop"
elif chunk.stop_reason == StopReason.end_of_message:
finish_reason = "eos"
elif chunk.stop_reason == StopReason.out_of_tokens:
finish_reason = "length"
else:
finish_reason = None
# Convert delta back to an actual delta
text_delta = chunk.delta[last_text_len:]
last_text_len = len(chunk.delta)
logger.debug(f"{text_delta=}; {finish_reason=}")
yield OpenAICompatCompletionResponse(
choices=[OpenAICompatCompletionChoice(finish_reason=finish_reason, text=text_delta)]
)
stream = _generate_and_convert_to_openai_compat()
async for chunk in process_chat_completion_stream_response(stream, request):
logger.debug(f"Returning chunk: {chunk}")
yield chunk
async def _convert_streaming_results(self, vllm_result: AsyncIterator) -> AsyncIterator:
"""
Subroutine that wraps the streaming outputs of vLLM's OpenAI-compatible
API into a second async iterator that returns Llama Stack objects.
:param vllm_result: Stream of strings that need to be parsed
"""
# Tool calls come in pieces, but Llama Stack expects them in bigger chunks. We build up
# those chunks and output them at the end.
# This data structure holds the current set of partial tool calls.
index_to_tool_call: dict[int, dict] = dict()
# The Llama Stack event stream must always start with a start event. Use an empty one to
# simplify logic below
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.start,
delta=TextDelta(text=""),
stop_reason=None,
)
)
converted_stop_reason = None
async for chunk_str in vllm_result:
# Due to OpenAI compatibility, each event in the stream will start with "data: " and
# end with "\n\n".
_prefix = "data: "
_suffix = "\n\n"
if not chunk_str.startswith(_prefix) or not chunk_str.endswith(_suffix):
raise ValueError(f"Can't parse result string from vLLM: '{re.escape(chunk_str)}'")
# In between the "data: " and newlines is an event record
data_str = chunk_str[len(_prefix) : -len(_suffix)]
# The end of the stream is indicated with "[DONE]"
if data_str == "[DONE]":
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.complete,
delta=TextDelta(text=""),
stop_reason=converted_stop_reason,
)
)
return
# Anything that is not "[DONE]" should be a JSON record
parsed_chunk = json.loads(data_str)
logger.debug(f"Parsed JSON event to:\n{json.dumps(parsed_chunk, indent=2)}")
# The result may contain multiple completions, but Llama Stack APIs only support
# returning one.
first_choice = parsed_chunk["choices"][0]
converted_stop_reason = get_stop_reason(first_choice["finish_reason"])
delta_record = first_choice["delta"]
if "content" in delta_record:
# Text delta
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
delta=TextDelta(text=delta_record["content"]),
stop_reason=converted_stop_reason,
)
)
elif "tool_calls" in delta_record:
# Tool call(s). Llama Stack APIs do not have a clear way to return partial tool
# calls, so buffer until we get a "tool calls" stop reason
for tc in delta_record["tool_calls"]:
index = tc["index"]
if index not in index_to_tool_call:
# First time this tool call is showing up
index_to_tool_call[index] = dict()
tool_call = index_to_tool_call[index]
if "id" in tc:
tool_call["call_id"] = tc["id"]
if "function" in tc:
if "name" in tc["function"]:
tool_call["tool_name"] = tc["function"]["name"]
if "arguments" in tc["function"]:
# Arguments come in as pieces of a string
if "arguments_str" not in tool_call:
tool_call["arguments_str"] = ""
tool_call["arguments_str"] += tc["function"]["arguments"]
else:
raise ValueError(f"Don't know how to parse event delta: {delta_record}")
if first_choice["finish_reason"] == "tool_calls":
# Special OpenAI code for "tool calls complete".
# Output the buffered tool calls. Llama Stack requires a separate event per tool
# call.
for tool_call_record in index_to_tool_call.values():
# Arguments come in as a string. Parse the completed string.
tool_call_record["arguments"] = json.loads(tool_call_record["arguments_str"])
del tool_call_record["arguments_str"]
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
delta=ToolCallDelta(tool_call=tool_call_record, parse_status="succeeded"),
stop_reason=converted_stop_reason,
)
)
# If we get here, we've lost the connection with the vLLM event stream before it ended
# normally.
raise ValueError("vLLM event stream ended without [DONE] message.")

View file

@ -181,8 +181,8 @@ class FaissVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPr
) )
self.cache[vector_db.identifier] = index self.cache[vector_db.identifier] = index
# Load existing OpenAI vector stores using the mixin method # Load existing OpenAI vector stores into the in-memory cache
self.openai_vector_stores = await self._load_openai_vector_stores() await self.initialize_openai_vector_stores()
async def shutdown(self) -> None: async def shutdown(self) -> None:
# Cleanup if needed # Cleanup if needed
@ -261,42 +261,6 @@ class FaissVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPr
return await index.query_chunks(query, params) return await index.query_chunks(query, params)
# OpenAI Vector Store Mixin abstract method implementations
async def _save_openai_vector_store(self, store_id: str, store_info: dict[str, Any]) -> None:
"""Save vector store metadata to kvstore."""
assert self.kvstore is not None
key = f"{OPENAI_VECTOR_STORES_PREFIX}{store_id}"
await self.kvstore.set(key=key, value=json.dumps(store_info))
self.openai_vector_stores[store_id] = store_info
async def _load_openai_vector_stores(self) -> dict[str, dict[str, Any]]:
"""Load all vector store metadata from kvstore."""
assert self.kvstore is not None
start_key = OPENAI_VECTOR_STORES_PREFIX
end_key = f"{OPENAI_VECTOR_STORES_PREFIX}\xff"
stored_openai_stores = await self.kvstore.values_in_range(start_key, end_key)
stores = {}
for store_data in stored_openai_stores:
store_info = json.loads(store_data)
stores[store_info["id"]] = store_info
return stores
async def _update_openai_vector_store(self, store_id: str, store_info: dict[str, Any]) -> None:
"""Update vector store metadata in kvstore."""
assert self.kvstore is not None
key = f"{OPENAI_VECTOR_STORES_PREFIX}{store_id}"
await self.kvstore.set(key=key, value=json.dumps(store_info))
self.openai_vector_stores[store_id] = store_info
async def _delete_openai_vector_store_from_storage(self, store_id: str) -> None:
"""Delete vector store metadata from kvstore."""
assert self.kvstore is not None
key = f"{OPENAI_VECTOR_STORES_PREFIX}{store_id}"
await self.kvstore.delete(key)
if store_id in self.openai_vector_stores:
del self.openai_vector_stores[store_id]
async def _save_openai_vector_store_file( async def _save_openai_vector_store_file(
self, store_id: str, file_id: str, file_info: dict[str, Any], file_contents: list[dict[str, Any]] self, store_id: str, file_id: str, file_info: dict[str, Any], file_contents: list[dict[str, Any]]
) -> None: ) -> None:

View file

@ -7,6 +7,7 @@
import asyncio import asyncio
import json import json
import logging import logging
import re
import sqlite3 import sqlite3
import struct import struct
from typing import Any from typing import Any
@ -117,6 +118,10 @@ def _rrf_rerank(
return rrf_scores return rrf_scores
def _make_sql_identifier(name: str) -> str:
return re.sub(r"[^a-zA-Z0-9_]", "_", name)
class SQLiteVecIndex(EmbeddingIndex): class SQLiteVecIndex(EmbeddingIndex):
""" """
An index implementation that stores embeddings in a SQLite virtual table using sqlite-vec. An index implementation that stores embeddings in a SQLite virtual table using sqlite-vec.
@ -130,9 +135,9 @@ class SQLiteVecIndex(EmbeddingIndex):
self.dimension = dimension self.dimension = dimension
self.db_path = db_path self.db_path = db_path
self.bank_id = bank_id self.bank_id = bank_id
self.metadata_table = f"chunks_{bank_id}".replace("-", "_") self.metadata_table = _make_sql_identifier(f"chunks_{bank_id}")
self.vector_table = f"vec_chunks_{bank_id}".replace("-", "_") self.vector_table = _make_sql_identifier(f"vec_chunks_{bank_id}")
self.fts_table = f"fts_chunks_{bank_id}".replace("-", "_") self.fts_table = _make_sql_identifier(f"fts_chunks_{bank_id}")
self.kvstore = kvstore self.kvstore = kvstore
@classmethod @classmethod
@ -148,14 +153,14 @@ class SQLiteVecIndex(EmbeddingIndex):
try: try:
# Create the table to store chunk metadata. # Create the table to store chunk metadata.
cur.execute(f""" cur.execute(f"""
CREATE TABLE IF NOT EXISTS {self.metadata_table} ( CREATE TABLE IF NOT EXISTS [{self.metadata_table}] (
id TEXT PRIMARY KEY, id TEXT PRIMARY KEY,
chunk TEXT chunk TEXT
); );
""") """)
# Create the virtual table for embeddings. # Create the virtual table for embeddings.
cur.execute(f""" cur.execute(f"""
CREATE VIRTUAL TABLE IF NOT EXISTS {self.vector_table} CREATE VIRTUAL TABLE IF NOT EXISTS [{self.vector_table}]
USING vec0(embedding FLOAT[{self.dimension}], id TEXT); USING vec0(embedding FLOAT[{self.dimension}], id TEXT);
""") """)
connection.commit() connection.commit()
@ -163,7 +168,7 @@ class SQLiteVecIndex(EmbeddingIndex):
# based on query. Implementation of the change on client side will allow passing the search_mode option # based on query. Implementation of the change on client side will allow passing the search_mode option
# during initialization to make it easier to create the table that is required. # during initialization to make it easier to create the table that is required.
cur.execute(f""" cur.execute(f"""
CREATE VIRTUAL TABLE IF NOT EXISTS {self.fts_table} CREATE VIRTUAL TABLE IF NOT EXISTS [{self.fts_table}]
USING fts5(id, content); USING fts5(id, content);
""") """)
connection.commit() connection.commit()
@ -178,9 +183,9 @@ class SQLiteVecIndex(EmbeddingIndex):
connection = _create_sqlite_connection(self.db_path) connection = _create_sqlite_connection(self.db_path)
cur = connection.cursor() cur = connection.cursor()
try: try:
cur.execute(f"DROP TABLE IF EXISTS {self.metadata_table};") cur.execute(f"DROP TABLE IF EXISTS [{self.metadata_table}];")
cur.execute(f"DROP TABLE IF EXISTS {self.vector_table};") cur.execute(f"DROP TABLE IF EXISTS [{self.vector_table}];")
cur.execute(f"DROP TABLE IF EXISTS {self.fts_table};") cur.execute(f"DROP TABLE IF EXISTS [{self.fts_table}];")
connection.commit() connection.commit()
finally: finally:
cur.close() cur.close()
@ -212,7 +217,7 @@ class SQLiteVecIndex(EmbeddingIndex):
metadata_data = [(chunk.chunk_id, chunk.model_dump_json()) for chunk in batch_chunks] metadata_data = [(chunk.chunk_id, chunk.model_dump_json()) for chunk in batch_chunks]
cur.executemany( cur.executemany(
f""" f"""
INSERT INTO {self.metadata_table} (id, chunk) INSERT INTO [{self.metadata_table}] (id, chunk)
VALUES (?, ?) VALUES (?, ?)
ON CONFLICT(id) DO UPDATE SET chunk = excluded.chunk; ON CONFLICT(id) DO UPDATE SET chunk = excluded.chunk;
""", """,
@ -230,7 +235,7 @@ class SQLiteVecIndex(EmbeddingIndex):
for chunk, emb in zip(batch_chunks, batch_embeddings, strict=True) for chunk, emb in zip(batch_chunks, batch_embeddings, strict=True)
] ]
cur.executemany( cur.executemany(
f"INSERT INTO {self.vector_table} (id, embedding) VALUES (?, ?);", f"INSERT INTO [{self.vector_table}] (id, embedding) VALUES (?, ?);",
embedding_data, embedding_data,
) )
@ -238,13 +243,13 @@ class SQLiteVecIndex(EmbeddingIndex):
fts_data = [(chunk.chunk_id, chunk.content) for chunk in batch_chunks] fts_data = [(chunk.chunk_id, chunk.content) for chunk in batch_chunks]
# DELETE existing entries with same IDs (FTS5 doesn't support ON CONFLICT) # DELETE existing entries with same IDs (FTS5 doesn't support ON CONFLICT)
cur.executemany( cur.executemany(
f"DELETE FROM {self.fts_table} WHERE id = ?;", f"DELETE FROM [{self.fts_table}] WHERE id = ?;",
[(row[0],) for row in fts_data], [(row[0],) for row in fts_data],
) )
# INSERT new entries # INSERT new entries
cur.executemany( cur.executemany(
f"INSERT INTO {self.fts_table} (id, content) VALUES (?, ?);", f"INSERT INTO [{self.fts_table}] (id, content) VALUES (?, ?);",
fts_data, fts_data,
) )
@ -280,8 +285,8 @@ class SQLiteVecIndex(EmbeddingIndex):
emb_blob = serialize_vector(emb_list) emb_blob = serialize_vector(emb_list)
query_sql = f""" query_sql = f"""
SELECT m.id, m.chunk, v.distance SELECT m.id, m.chunk, v.distance
FROM {self.vector_table} AS v FROM [{self.vector_table}] AS v
JOIN {self.metadata_table} AS m ON m.id = v.id JOIN [{self.metadata_table}] AS m ON m.id = v.id
WHERE v.embedding MATCH ? AND k = ? WHERE v.embedding MATCH ? AND k = ?
ORDER BY v.distance; ORDER BY v.distance;
""" """
@ -322,9 +327,9 @@ class SQLiteVecIndex(EmbeddingIndex):
cur = connection.cursor() cur = connection.cursor()
try: try:
query_sql = f""" query_sql = f"""
SELECT DISTINCT m.id, m.chunk, bm25({self.fts_table}) AS score SELECT DISTINCT m.id, m.chunk, bm25([{self.fts_table}]) AS score
FROM {self.fts_table} AS f FROM [{self.fts_table}] AS f
JOIN {self.metadata_table} AS m ON m.id = f.id JOIN [{self.metadata_table}] AS m ON m.id = f.id
WHERE f.content MATCH ? WHERE f.content MATCH ?
ORDER BY score ASC ORDER BY score ASC
LIMIT ?; LIMIT ?;
@ -452,8 +457,8 @@ class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtoc
) )
self.cache[vector_db.identifier] = VectorDBWithIndex(vector_db, index, self.inference_api) self.cache[vector_db.identifier] = VectorDBWithIndex(vector_db, index, self.inference_api)
# load any existing OpenAI vector stores # Load existing OpenAI vector stores into the in-memory cache
self.openai_vector_stores = await self._load_openai_vector_stores() await self.initialize_openai_vector_stores()
async def shutdown(self) -> None: async def shutdown(self) -> None:
# nothing to do since we don't maintain a persistent connection # nothing to do since we don't maintain a persistent connection
@ -501,41 +506,6 @@ class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtoc
await self.cache[vector_db_id].index.delete() await self.cache[vector_db_id].index.delete()
del self.cache[vector_db_id] del self.cache[vector_db_id]
# OpenAI Vector Store Mixin abstract method implementations
async def _save_openai_vector_store(self, store_id: str, store_info: dict[str, Any]) -> None:
"""Save vector store metadata to SQLite database."""
assert self.kvstore is not None
key = f"{OPENAI_VECTOR_STORES_PREFIX}{store_id}"
await self.kvstore.set(key=key, value=json.dumps(store_info))
self.openai_vector_stores[store_id] = store_info
async def _load_openai_vector_stores(self) -> dict[str, dict[str, Any]]:
"""Load all vector store metadata from SQLite database."""
assert self.kvstore is not None
start_key = OPENAI_VECTOR_STORES_PREFIX
end_key = f"{OPENAI_VECTOR_STORES_PREFIX}\xff"
stored_openai_stores = await self.kvstore.values_in_range(start_key, end_key)
stores = {}
for store_data in stored_openai_stores:
store_info = json.loads(store_data)
stores[store_info["id"]] = store_info
return stores
async def _update_openai_vector_store(self, store_id: str, store_info: dict[str, Any]) -> None:
"""Update vector store metadata in SQLite database."""
assert self.kvstore is not None
key = f"{OPENAI_VECTOR_STORES_PREFIX}{store_id}"
await self.kvstore.set(key=key, value=json.dumps(store_info))
self.openai_vector_stores[store_id] = store_info
async def _delete_openai_vector_store_from_storage(self, store_id: str) -> None:
"""Delete vector store metadata from SQLite database."""
assert self.kvstore is not None
key = f"{OPENAI_VECTOR_STORES_PREFIX}{store_id}"
await self.kvstore.delete(key)
if store_id in self.openai_vector_stores:
del self.openai_vector_stores[store_id]
async def _save_openai_vector_store_file( async def _save_openai_vector_store_file(
self, store_id: str, file_id: str, file_info: dict[str, Any], file_contents: list[dict[str, Any]] self, store_id: str, file_id: str, file_info: dict[str, Any], file_contents: list[dict[str, Any]]
) -> None: ) -> None:

View file

@ -37,16 +37,6 @@ def available_providers() -> list[ProviderSpec]:
config_class="llama_stack.providers.inline.inference.meta_reference.MetaReferenceInferenceConfig", config_class="llama_stack.providers.inline.inference.meta_reference.MetaReferenceInferenceConfig",
description="Meta's reference implementation of inference with support for various model formats and optimization techniques.", description="Meta's reference implementation of inference with support for various model formats and optimization techniques.",
), ),
InlineProviderSpec(
api=Api.inference,
provider_type="inline::vllm",
pip_packages=[
"vllm",
],
module="llama_stack.providers.inline.inference.vllm",
config_class="llama_stack.providers.inline.inference.vllm.VLLMConfig",
description="vLLM inference provider for high-performance model serving with PagedAttention and continuous batching.",
),
InlineProviderSpec( InlineProviderSpec(
api=Api.inference, api=Api.inference,
provider_type="inline::sentence-transformers", provider_type="inline::sentence-transformers",

View file

@ -3,16 +3,17 @@
# #
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
import logging
from llama_stack.providers.remote.inference.llama_openai_compat.config import ( from llama_api_client import AsyncLlamaAPIClient, NotFoundError
LlamaCompatConfig,
) from llama_stack.providers.remote.inference.llama_openai_compat.config import LlamaCompatConfig
from llama_stack.providers.utils.inference.litellm_openai_mixin import ( from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
LiteLLMOpenAIMixin,
)
from .models import MODEL_ENTRIES from .models import MODEL_ENTRIES
logger = logging.getLogger(__name__)
class LlamaCompatInferenceAdapter(LiteLLMOpenAIMixin): class LlamaCompatInferenceAdapter(LiteLLMOpenAIMixin):
_config: LlamaCompatConfig _config: LlamaCompatConfig
@ -27,8 +28,32 @@ class LlamaCompatInferenceAdapter(LiteLLMOpenAIMixin):
) )
self.config = config self.config = config
async def check_model_availability(self, model: str) -> bool:
"""
Check if a specific model is available from Llama API.
:param model: The model identifier to check.
:return: True if the model is available dynamically, False otherwise.
"""
try:
llama_api_client = self._get_llama_api_client()
retrieved_model = await llama_api_client.models.retrieve(model)
logger.info(f"Model {retrieved_model.id} is available from Llama API")
return True
except NotFoundError:
logger.error(f"Model {model} is not available from Llama API")
return False
except Exception as e:
logger.error(f"Failed to check model availability from Llama API: {e}")
return False
async def initialize(self): async def initialize(self):
await super().initialize() await super().initialize()
async def shutdown(self): async def shutdown(self):
await super().shutdown() await super().shutdown()
def _get_llama_api_client(self) -> AsyncLlamaAPIClient:
return AsyncLlamaAPIClient(api_key=self.get_api_key(), base_url=self.config.openai_compat_api_base)
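
The check_model_availability() hook added here (and mirrored in the OpenAI and NVIDIA adapters below) follows a simple contract: return True when the remote model listing knows the ID, and False on a not-found error or any other failure. Below is a self-contained sketch of that contract using a stand-in client rather than the real AsyncLlamaAPIClient; the model ID and class names are illustrative only.

import asyncio

class NotFound(Exception):
    """Stand-in for the client library's NotFoundError."""

class FakeModelsClient:
    known = {"Llama-4-Scout-17B-16E-Instruct-FP8"}  # illustrative model ID

    async def retrieve(self, model: str) -> str:
        if model not in self.known:
            raise NotFound(model)
        return model

async def check_model_availability(client: FakeModelsClient, model: str) -> bool:
    try:
        await client.retrieve(model)
        return True
    except NotFound:
        return False
    except Exception:
        return False

async def main() -> None:
    client = FakeModelsClient()
    print(await check_model_availability(client, "Llama-4-Scout-17B-16E-Instruct-FP8"))  # True
    print(await check_model_availability(client, "no-such-model"))                       # False

asyncio.run(main())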

View file

@ -7,10 +7,9 @@
import logging import logging
import warnings import warnings
from collections.abc import AsyncIterator from collections.abc import AsyncIterator
from functools import lru_cache
from typing import Any from typing import Any
from openai import APIConnectionError, AsyncOpenAI, BadRequestError from openai import APIConnectionError, AsyncOpenAI, BadRequestError, NotFoundError
from llama_stack.apis.common.content_types import ( from llama_stack.apis.common.content_types import (
InterleavedContent, InterleavedContent,
@ -41,11 +40,7 @@ from llama_stack.apis.inference import (
ToolChoice, ToolChoice,
ToolConfig, ToolConfig,
) )
from llama_stack.apis.models import Model, ModelType
from llama_stack.models.llama.datatypes import ToolDefinition, ToolPromptFormat from llama_stack.models.llama.datatypes import ToolDefinition, ToolPromptFormat
from llama_stack.providers.utils.inference import (
ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR,
)
from llama_stack.providers.utils.inference.model_registry import ( from llama_stack.providers.utils.inference.model_registry import (
ModelRegistryHelper, ModelRegistryHelper,
) )
@ -93,41 +88,37 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
self._config = config self._config = config
@lru_cache # noqa: B019 async def check_model_availability(self, model: str) -> bool:
def _get_client(self, provider_model_id: str) -> AsyncOpenAI:
""" """
For hosted models, https://integrate.api.nvidia.com/v1 is the primary base_url. However, Check if a specific model is available.
some models are hosted on different URLs. This function returns the appropriate client
for the given provider_model_id.
This relies on lru_cache and self._default_client to avoid creating a new client for each request :param model: The model identifier to check.
or for each model that is hosted on https://integrate.api.nvidia.com/v1. :return: True if the model is available dynamically, False otherwise.
"""
try:
await self._client.models.retrieve(model)
return True
except NotFoundError:
logger.error(f"Model {model} is not available")
except Exception as e:
logger.error(f"Failed to check model availability: {e}")
return False
@property
def _client(self) -> AsyncOpenAI:
"""
Returns an OpenAI client for the configured NVIDIA API endpoint.
:param provider_model_id: The provider model ID
:return: An OpenAI client :return: An OpenAI client
""" """
@lru_cache # noqa: B019
def _get_client_for_base_url(base_url: str) -> AsyncOpenAI:
"""
Maintain a single OpenAI client per base_url.
"""
return AsyncOpenAI(
base_url=base_url,
api_key=(self._config.api_key.get_secret_value() if self._config.api_key else "NO KEY"),
timeout=self._config.timeout,
)
special_model_urls = {
"meta/llama-3.2-11b-vision-instruct": "https://ai.api.nvidia.com/v1/gr/meta/llama-3.2-11b-vision-instruct",
"meta/llama-3.2-90b-vision-instruct": "https://ai.api.nvidia.com/v1/gr/meta/llama-3.2-90b-vision-instruct",
}
base_url = f"{self._config.url}/v1" if self._config.append_api_version else self._config.url base_url = f"{self._config.url}/v1" if self._config.append_api_version else self._config.url
if _is_nvidia_hosted(self._config) and provider_model_id in special_model_urls: return AsyncOpenAI(
base_url = special_model_urls[provider_model_id] base_url=base_url,
return _get_client_for_base_url(base_url) api_key=(self._config.api_key.get_secret_value() if self._config.api_key else "NO KEY"),
timeout=self._config.timeout,
)
async def _get_provider_model_id(self, model_id: str) -> str: async def _get_provider_model_id(self, model_id: str) -> str:
if not self.model_store: if not self.model_store:
@ -169,7 +160,7 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
) )
try: try:
response = await self._get_client(provider_model_id).completions.create(**request) response = await self._client.completions.create(**request)
except APIConnectionError as e: except APIConnectionError as e:
raise ConnectionError(f"Failed to connect to NVIDIA NIM at {self._config.url}: {e}") from e raise ConnectionError(f"Failed to connect to NVIDIA NIM at {self._config.url}: {e}") from e
@ -222,7 +213,7 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
extra_body["input_type"] = task_type_options[task_type] extra_body["input_type"] = task_type_options[task_type]
try: try:
response = await self._get_client(provider_model_id).embeddings.create( response = await self._client.embeddings.create(
model=provider_model_id, model=provider_model_id,
input=input, input=input,
extra_body=extra_body, extra_body=extra_body,
@ -283,7 +274,7 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
) )
try: try:
response = await self._get_client(provider_model_id).chat.completions.create(**request) response = await self._client.chat.completions.create(**request)
except APIConnectionError as e: except APIConnectionError as e:
raise ConnectionError(f"Failed to connect to NVIDIA NIM at {self._config.url}: {e}") from e raise ConnectionError(f"Failed to connect to NVIDIA NIM at {self._config.url}: {e}") from e
@ -339,7 +330,7 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
) )
try: try:
return await self._get_client(provider_model_id).completions.create(**params) return await self._client.completions.create(**params)
except APIConnectionError as e: except APIConnectionError as e:
raise ConnectionError(f"Failed to connect to NVIDIA NIM at {self._config.url}: {e}") from e raise ConnectionError(f"Failed to connect to NVIDIA NIM at {self._config.url}: {e}") from e
@ -398,47 +389,6 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
) )
try: try:
return await self._get_client(provider_model_id).chat.completions.create(**params) return await self._client.chat.completions.create(**params)
except APIConnectionError as e: except APIConnectionError as e:
raise ConnectionError(f"Failed to connect to NVIDIA NIM at {self._config.url}: {e}") from e raise ConnectionError(f"Failed to connect to NVIDIA NIM at {self._config.url}: {e}") from e
async def register_model(self, model: Model) -> Model:
"""
Allow non-llama model registration.
Non-llama model registration: API Catalogue models, post-training models, etc.
client = LlamaStackAsLibraryClient("nvidia")
client.models.register(
model_id="mistralai/mixtral-8x7b-instruct-v0.1",
model_type=ModelType.llm,
provider_id="nvidia",
provider_model_id="mistralai/mixtral-8x7b-instruct-v0.1"
)
NOTE: Only supports models endpoints compatible with AsyncOpenAI base_url format.
"""
if model.model_type == ModelType.embedding:
# embedding models are always registered by their provider model id and does not need to be mapped to a llama model
provider_resource_id = model.provider_resource_id
else:
provider_resource_id = self.get_provider_model_id(model.provider_resource_id)
if provider_resource_id:
model.provider_resource_id = provider_resource_id
else:
llama_model = model.metadata.get("llama_model")
existing_llama_model = self.get_llama_model(model.provider_resource_id)
if existing_llama_model:
if existing_llama_model != llama_model:
raise ValueError(
f"Provider model id '{model.provider_resource_id}' is already registered to a different llama model: '{existing_llama_model}'"
)
else:
# not llama model
if llama_model in ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR:
self.provider_id_to_llama_model_map[model.provider_resource_id] = (
ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR[llama_model]
)
else:
self.alias_to_provider_id_map[model.provider_model_id] = model.provider_model_id
return model

View file

@ -6,13 +6,15 @@
from typing import Any from typing import Any
from pydantic import BaseModel from pydantic import BaseModel, Field
DEFAULT_OLLAMA_URL = "http://localhost:11434" DEFAULT_OLLAMA_URL = "http://localhost:11434"
class OllamaImplConfig(BaseModel): class OllamaImplConfig(BaseModel):
url: str = DEFAULT_OLLAMA_URL url: str = DEFAULT_OLLAMA_URL
refresh_models: bool = Field(default=False, description="refresh and re-register models periodically")
refresh_models_interval: int = Field(default=300, description="interval in seconds to refresh models")
@classmethod @classmethod
def sample_run_config(cls, url: str = "${env.OLLAMA_URL:=http://localhost:11434}", **kwargs) -> dict[str, Any]: def sample_run_config(cls, url: str = "${env.OLLAMA_URL:=http://localhost:11434}", **kwargs) -> dict[str, Any]:
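
A hedged illustration of the two new Ollama config fields; the import path is assumed from the repository layout and the values are arbitrary.

# Assumed import path; the field names match the diff above.
from llama_stack.providers.remote.inference.ollama.config import OllamaImplConfig

config = OllamaImplConfig(
    url="http://localhost:11434",
    refresh_models=True,         # opt in to periodic re-registration of served models
    refresh_models_interval=60,  # seconds between refreshes (default is 300)
)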

View file

@ -12,6 +12,19 @@ from llama_stack.providers.utils.inference.model_registry import (
build_model_entry, build_model_entry,
) )
SAFETY_MODELS_ENTRIES = [
# The Llama Guard models don't have their full fp16 versions
# so we are going to alias their default version to the canonical SKU
build_hf_repo_model_entry(
"llama-guard3:8b",
CoreModelId.llama_guard_3_8b.value,
),
build_hf_repo_model_entry(
"llama-guard3:1b",
CoreModelId.llama_guard_3_1b.value,
),
]
MODEL_ENTRIES = [ MODEL_ENTRIES = [
build_hf_repo_model_entry( build_hf_repo_model_entry(
"llama3.1:8b-instruct-fp16", "llama3.1:8b-instruct-fp16",
@ -73,16 +86,6 @@ MODEL_ENTRIES = [
"llama3.3:70b", "llama3.3:70b",
CoreModelId.llama3_3_70b_instruct.value, CoreModelId.llama3_3_70b_instruct.value,
), ),
# The Llama Guard models don't have their full fp16 versions
# so we are going to alias their default version to the canonical SKU
build_hf_repo_model_entry(
"llama-guard3:8b",
CoreModelId.llama_guard_3_8b.value,
),
build_hf_repo_model_entry(
"llama-guard3:1b",
CoreModelId.llama_guard_3_1b.value,
),
ProviderModelEntry( ProviderModelEntry(
provider_model_id="all-minilm:l6-v2", provider_model_id="all-minilm:l6-v2",
aliases=["all-minilm"], aliases=["all-minilm"],
@ -100,4 +103,4 @@ MODEL_ENTRIES = [
"context_length": 8192, "context_length": 8192,
}, },
), ),
] ] + SAFETY_MODELS_ENTRIES

View file

@ -5,6 +5,7 @@
# the root directory of this source tree. # the root directory of this source tree.
import asyncio
import base64 import base64
import uuid import uuid
from collections.abc import AsyncGenerator, AsyncIterator from collections.abc import AsyncGenerator, AsyncIterator
@ -89,23 +90,88 @@ class OllamaInferenceAdapter(
InferenceProvider, InferenceProvider,
ModelRegistryHelper, ModelRegistryHelper,
): ):
# automatically set by the resolver when instantiating the provider
__provider_id__: str
def __init__(self, config: OllamaImplConfig) -> None: def __init__(self, config: OllamaImplConfig) -> None:
ModelRegistryHelper.__init__(self, MODEL_ENTRIES) ModelRegistryHelper.__init__(self, MODEL_ENTRIES)
self.url = config.url self.config = config
self._client = None
self._openai_client = None
@property @property
def client(self) -> AsyncClient: def client(self) -> AsyncClient:
return AsyncClient(host=self.url) if self._client is None:
self._client = AsyncClient(host=self.config.url)
return self._client
@property @property
def openai_client(self) -> AsyncOpenAI: def openai_client(self) -> AsyncOpenAI:
return AsyncOpenAI(base_url=f"{self.url}/v1", api_key="ollama") if self._openai_client is None:
self._openai_client = AsyncOpenAI(base_url=f"{self.config.url}/v1", api_key="ollama")
return self._openai_client
async def initialize(self) -> None: async def initialize(self) -> None:
logger.debug(f"checking connectivity to Ollama at `{self.url}`...") logger.info(f"checking connectivity to Ollama at `{self.config.url}`...")
health_response = await self.health() health_response = await self.health()
if health_response["status"] == HealthStatus.ERROR: if health_response["status"] == HealthStatus.ERROR:
raise RuntimeError("Ollama Server is not running, start it using `ollama serve` in a separate terminal") logger.warning(
"Ollama Server is not running, make sure to start it using `ollama serve` in a separate terminal"
)
if self.config.refresh_models:
logger.debug("ollama starting background model refresh task")
self._refresh_task = asyncio.create_task(self._refresh_models())
def cb(task):
if task.cancelled():
import traceback
logger.error(f"ollama background refresh task canceled:\n{''.join(traceback.format_stack())}")
elif task.exception():
logger.error(f"ollama background refresh task died: {task.exception()}")
else:
logger.error("ollama background refresh task completed unexpectedly")
self._refresh_task.add_done_callback(cb)
async def _refresh_models(self) -> None:
# Wait for model store to be available (with timeout)
waited_time = 0
while not self.model_store and waited_time < 60:
await asyncio.sleep(1)
waited_time += 1
if not self.model_store:
raise ValueError("Model store not set after waiting 60 seconds")
provider_id = self.__provider_id__
while True:
try:
response = await self.client.list()
except Exception as e:
logger.warning(f"Failed to list models: {str(e)}")
await asyncio.sleep(self.config.refresh_models_interval)
continue
models = []
for m in response.models:
model_type = ModelType.embedding if m.details.family in ["bert"] else ModelType.llm
if model_type == ModelType.embedding:
continue
models.append(
Model(
identifier=m.model,
provider_resource_id=m.model,
provider_id=provider_id,
metadata={},
model_type=model_type,
)
)
await self.model_store.update_registered_llm_models(provider_id, models)
logger.debug(f"ollama refreshed model list ({len(models)} models)")
await asyncio.sleep(self.config.refresh_models_interval)
async def health(self) -> HealthResponse: async def health(self) -> HealthResponse:
""" """
@ -157,7 +223,12 @@ class OllamaInferenceAdapter(
return available_models return available_models
async def shutdown(self) -> None: async def shutdown(self) -> None:
pass if hasattr(self, "_refresh_task") and not self._refresh_task.done():
logger.debug("ollama cancelling background refresh task")
self._refresh_task.cancel()
self._client = None
self._openai_client = None
async def unregister_model(self, model_id: str) -> None: async def unregister_model(self, model_id: str) -> None:
pass pass
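
The refresh wiring above follows a common asyncio pattern: start the loop in initialize(), attach a done callback so unexpected exits are logged, and cancel the task in shutdown(). A generic, self-contained sketch of that pattern (names and intervals are illustrative, not the adapter's API):

import asyncio

async def _refresh_forever(interval: float) -> None:
    while True:
        # ... list models from the backend and re-register them here ...
        await asyncio.sleep(interval)

def _log_exit(task: asyncio.Task) -> None:
    if task.cancelled():
        print("refresh task cancelled (normal shutdown)")
    elif task.exception() is not None:
        print(f"refresh task died: {task.exception()}")
    else:
        print("refresh task completed unexpectedly")

async def main() -> None:
    task = asyncio.create_task(_refresh_forever(interval=0.01))
    task.add_done_callback(_log_exit)
    await asyncio.sleep(0.05)  # stand-in for the server's lifetime
    task.cancel()              # mirrors shutdown() cancelling the stored refresh task
    try:
        await task
    except asyncio.CancelledError:
        pass

asyncio.run(main())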

View file

@ -8,7 +8,7 @@ import logging
from collections.abc import AsyncIterator from collections.abc import AsyncIterator
from typing import Any from typing import Any
from openai import AsyncOpenAI from openai import AsyncOpenAI, NotFoundError
from llama_stack.apis.inference import ( from llama_stack.apis.inference import (
OpenAIChatCompletion, OpenAIChatCompletion,
@ -60,6 +60,27 @@ class OpenAIInferenceAdapter(LiteLLMOpenAIMixin):
# litellm specific model names, an abstraction leak. # litellm specific model names, an abstraction leak.
self.is_openai_compat = True self.is_openai_compat = True
async def check_model_availability(self, model: str) -> bool:
"""
Check if a specific model is available from OpenAI.
:param model: The model identifier to check.
:return: True if the model is available dynamically, False otherwise.
"""
try:
openai_client = self._get_openai_client()
retrieved_model = await openai_client.models.retrieve(model)
logger.info(f"Model {retrieved_model.id} is available from OpenAI")
return True
except NotFoundError:
logger.error(f"Model {model} is not available from OpenAI")
return False
except Exception as e:
logger.error(f"Failed to check model availability from OpenAI: {e}")
return False
async def initialize(self) -> None: async def initialize(self) -> None:
await super().initialize() await super().initialize()

View file

@ -29,6 +29,14 @@ class VLLMInferenceAdapterConfig(BaseModel):
default=True, default=True,
description="Whether to verify TLS certificates. Can be a boolean or a path to a CA certificate file.", description="Whether to verify TLS certificates. Can be a boolean or a path to a CA certificate file.",
) )
refresh_models: bool = Field(
default=False,
description="Whether to refresh models periodically",
)
refresh_models_interval: int = Field(
default=300,
description="Interval in seconds to refresh models",
)
@field_validator("tls_verify") @field_validator("tls_verify")
@classmethod @classmethod
@ -46,7 +54,7 @@ class VLLMInferenceAdapterConfig(BaseModel):
@classmethod @classmethod
def sample_run_config( def sample_run_config(
cls, cls,
url: str = "${env.VLLM_URL}", url: str = "${env.VLLM_URL:=}",
**kwargs, **kwargs,
): ):
return { return {

View file

@ -3,8 +3,8 @@
# #
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
import asyncio
import json import json
import logging
from collections.abc import AsyncGenerator, AsyncIterator from collections.abc import AsyncGenerator, AsyncIterator
from typing import Any from typing import Any
@ -38,6 +38,7 @@ from llama_stack.apis.inference import (
JsonSchemaResponseFormat, JsonSchemaResponseFormat,
LogProbConfig, LogProbConfig,
Message, Message,
ModelStore,
OpenAIChatCompletion, OpenAIChatCompletion,
OpenAICompletion, OpenAICompletion,
OpenAIEmbeddingData, OpenAIEmbeddingData,
@ -54,6 +55,7 @@ from llama_stack.apis.inference import (
ToolPromptFormat, ToolPromptFormat,
) )
from llama_stack.apis.models import Model, ModelType from llama_stack.apis.models import Model, ModelType
from llama_stack.log import get_logger
from llama_stack.models.llama.datatypes import BuiltinTool, StopReason, ToolCall from llama_stack.models.llama.datatypes import BuiltinTool, StopReason, ToolCall
from llama_stack.models.llama.sku_list import all_registered_models from llama_stack.models.llama.sku_list import all_registered_models
from llama_stack.providers.datatypes import ( from llama_stack.providers.datatypes import (
@ -84,7 +86,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
from .config import VLLMInferenceAdapterConfig from .config import VLLMInferenceAdapterConfig
log = logging.getLogger(__name__) log = get_logger(name=__name__, category="inference")
def build_hf_repo_model_entries(): def build_hf_repo_model_entries():
@ -288,16 +290,76 @@ async def _process_vllm_chat_completion_stream_response(
class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate): class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
# automatically set by the resolver when instantiating the provider
__provider_id__: str
model_store: ModelStore | None = None
_refresh_task: asyncio.Task | None = None
def __init__(self, config: VLLMInferenceAdapterConfig) -> None: def __init__(self, config: VLLMInferenceAdapterConfig) -> None:
self.register_helper = ModelRegistryHelper(build_hf_repo_model_entries()) self.register_helper = ModelRegistryHelper(build_hf_repo_model_entries())
self.config = config self.config = config
self.client = None self.client = None
async def initialize(self) -> None: async def initialize(self) -> None:
pass if not self.config.url:
# intentionally don't raise an error here, we want to allow the provider to be "dormant"
# or available in distributions like "starter" without causing a ruckus
return
if self.config.refresh_models:
self._refresh_task = asyncio.create_task(self._refresh_models())
def cb(task):
import traceback
if task.cancelled():
log.error(f"vLLM background refresh task canceled:\n{''.join(traceback.format_stack())}")
elif task.exception():
# print the stack trace for the exception
exc = task.exception()
log.error(f"vLLM background refresh task died: {exc}")
traceback.print_exception(exc)
else:
log.error("vLLM background refresh task completed unexpectedly")
self._refresh_task.add_done_callback(cb)
async def _refresh_models(self) -> None:
provider_id = self.__provider_id__
waited_time = 0
while not self.model_store and waited_time < 60:
await asyncio.sleep(1)
waited_time += 1
if not self.model_store:
raise ValueError("Model store not set after waiting 60 seconds")
self._lazy_initialize_client()
assert self.client is not None # mypy
while True:
try:
models = []
async for m in self.client.models.list():
model_type = ModelType.llm # unclear how to determine embedding vs. llm models
models.append(
Model(
identifier=m.id,
provider_resource_id=m.id,
provider_id=provider_id,
metadata={},
model_type=model_type,
)
)
await self.model_store.update_registered_llm_models(provider_id, models)
log.debug(f"vLLM refreshed model list ({len(models)} models)")
except Exception as e:
log.error(f"vLLM background refresh task failed: {e}")
await asyncio.sleep(self.config.refresh_models_interval)
async def shutdown(self) -> None: async def shutdown(self) -> None:
pass if self._refresh_task:
self._refresh_task.cancel()
self._refresh_task = None
async def unregister_model(self, model_id: str) -> None: async def unregister_model(self, model_id: str) -> None:
pass pass
@ -312,6 +374,9 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
HealthResponse: A dictionary containing the health status. HealthResponse: A dictionary containing the health status.
""" """
try: try:
if not self.config.url:
return HealthResponse(status=HealthStatus.ERROR, message="vLLM URL is not set")
client = self._create_client() if self.client is None else self.client client = self._create_client() if self.client is None else self.client
_ = [m async for m in client.models.list()] # Ensure the client is initialized _ = [m async for m in client.models.list()] # Ensure the client is initialized
return HealthResponse(status=HealthStatus.OK) return HealthResponse(status=HealthStatus.OK)
@ -327,6 +392,11 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
if self.client is not None: if self.client is not None:
return return
if not self.config.url:
raise ValueError(
"You must provide a vLLM URL in the run.yaml file (or set the VLLM_URL environment variable)"
)
log.info(f"Initializing vLLM client with base_url={self.config.url}") log.info(f"Initializing vLLM client with base_url={self.config.url}")
self.client = self._create_client() self.client = self._create_client()

View file

@ -217,7 +217,6 @@ class ChromaVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
embedding_model: str | None = None, embedding_model: str | None = None,
embedding_dimension: int | None = 384, embedding_dimension: int | None = 384,
provider_id: str | None = None, provider_id: str | None = None,
provider_vector_db_id: str | None = None,
) -> VectorStoreObject: ) -> VectorStoreObject:
raise NotImplementedError("OpenAI Vector Stores API is not supported in Chroma") raise NotImplementedError("OpenAI Vector Stores API is not supported in Chroma")

View file

@ -8,7 +8,7 @@ from typing import Any
from pydantic import BaseModel, ConfigDict, Field from pydantic import BaseModel, ConfigDict, Field
from llama_stack.providers.utils.kvstore.config import KVStoreConfig from llama_stack.providers.utils.kvstore.config import KVStoreConfig, SqliteKVStoreConfig
from llama_stack.schema_utils import json_schema_type from llama_stack.schema_utils import json_schema_type
@ -17,7 +17,7 @@ class MilvusVectorIOConfig(BaseModel):
uri: str = Field(description="The URI of the Milvus server") uri: str = Field(description="The URI of the Milvus server")
token: str | None = Field(description="The token of the Milvus server") token: str | None = Field(description="The token of the Milvus server")
consistency_level: str = Field(description="The consistency level of the Milvus server", default="Strong") consistency_level: str = Field(description="The consistency level of the Milvus server", default="Strong")
kvstore: KVStoreConfig | None = Field(description="Config for KV store backend (SQLite only for now)", default=None) kvstore: KVStoreConfig = Field(description="Config for KV store backend")
# This configuration allows additional fields to be passed through to the underlying Milvus client. # This configuration allows additional fields to be passed through to the underlying Milvus client.
# See the [Milvus](https://milvus.io/docs/install-overview.md) documentation for more details about Milvus in general. # See the [Milvus](https://milvus.io/docs/install-overview.md) documentation for more details about Milvus in general.
@ -25,4 +25,11 @@ class MilvusVectorIOConfig(BaseModel):
@classmethod @classmethod
def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> dict[str, Any]: def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> dict[str, Any]:
return {"uri": "${env.MILVUS_ENDPOINT}", "token": "${env.MILVUS_TOKEN}"} return {
"uri": "${env.MILVUS_ENDPOINT}",
"token": "${env.MILVUS_TOKEN}",
"kvstore": SqliteKVStoreConfig.sample_run_config(
__distro_dir__=__distro_dir__,
db_name="milvus_remote_registry.db",
),
}

Some files were not shown because too many files have changed in this diff.