Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-12-22 20:42:24 +00:00)

Merge branch 'main' into chore/strong-typing

Commit 16d6a7a22f: 84 changed files with 3177 additions and 2793 deletions
@@ -4,3 +4,9 @@ omit =
    */llama_stack/providers/*
    */llama_stack/templates/*
    .venv/*
    */llama_stack/cli/scripts/*
    */llama_stack/ui/*
    */llama_stack/distribution/ui/*
    */llama_stack/strong_typing/*
    */llama_stack/env.py
    */__init__.py
.github/CODEOWNERS (2 changes, vendored)
@@ -2,4 +2,4 @@
 # These owners will be the default owners for everything in
 # the repo. Unless a later match takes precedence,
-* @ashwinb @yanxi0830 @hardikjshah @raghotham @ehhuang @terrytangyuan @leseb @bbrowning @reluctantfuturist @mattf
+* @ashwinb @yanxi0830 @hardikjshah @raghotham @ehhuang @terrytangyuan @leseb @bbrowning @reluctantfuturist @mattf @slekkala1

.github/actions/setup-ollama/action.yml (2 changes, vendored)
@@ -7,7 +7,5 @@ runs:
       shell: bash
       run: |
         docker run -d --name ollama -p 11434:11434 docker.io/leseb/ollama-with-models
-        # TODO: rebuild an ollama image with llama-guard3:1b
         echo "Verifying Ollama status..."
         timeout 30 bash -c 'while ! curl -s -L http://127.0.0.1:11434; do sleep 1 && echo "."; done'
-        docker exec ollama ollama pull llama-guard3:1b
.github/workflows/coverage-badge.yml (new file, 57 lines, vendored)
@@ -0,0 +1,57 @@
name: Coverage Badge

on:
  push:
    branches: [ main ]
    paths:
      - 'llama_stack/**'
      - 'tests/unit/**'
      - 'uv.lock'
      - 'pyproject.toml'
      - 'requirements.txt'
      - '.github/workflows/unit-tests.yml'
      - '.github/workflows/coverage-badge.yml' # This workflow
  workflow_dispatch:

jobs:
  unit-tests:
    runs-on: ubuntu-latest
    steps:
      - name: Checkout repository
        uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2

      - name: Install dependencies
        uses: ./.github/actions/setup-runner

      - name: Run unit tests
        run: |
          ./scripts/unit-tests.sh

      - name: Coverage Badge
        uses: tj-actions/coverage-badge-py@1788babcb24544eb5bbb6e0d374df5d1e54e670f # v2.0.4

      - name: Verify Changed files
        uses: tj-actions/verify-changed-files@a1c6acee9df209257a246f2cc6ae8cb6581c1edf # v20.0.4
        id: verify-changed-files
        with:
          files: coverage.svg

      - name: Commit files
        if: steps.verify-changed-files.outputs.files_changed == 'true'
        run: |
          git config --local user.email "github-actions[bot]@users.noreply.github.com"
          git config --local user.name "github-actions[bot]"
          git add coverage.svg
          git commit -m "Updated coverage.svg"

      - name: Create Pull Request
        if: steps.verify-changed-files.outputs.files_changed == 'true'
        uses: peter-evans/create-pull-request@271a8d0340265f705b14b6d32b9829c1cb33d45e # v7.0.8
        with:
          token: ${{ secrets.GITHUB_TOKEN }}
          title: "ci: [Automatic] Coverage Badge Update"
          body: |
            This PR updates the coverage badge based on the latest coverage report.

            Automatically generated by the [workflow coverage-badge.yaml](.github/workflows/coverage-badge.yaml)
          delete-branch: true
.github/workflows/gha_workflow_llama_stack_tests.yml (deleted file, 355 lines, vendored)
@@ -1,355 +0,0 @@
name: "Run Llama-stack Tests"

on:
  #### Temporarily disable PR runs until tests run as intended within mainline.
  #TODO Add this back.
  #pull_request_target:
  #  types: ["opened"]
  #  branches:
  #    - 'main'
  #  paths:
  #    - 'llama_stack/**/*.py'
  #    - 'tests/**/*.py'

  workflow_dispatch:
    inputs:
      runner:
        description: 'GHA Runner Scale Set label to run workflow on.'
        required: true
        default: "llama-stack-gha-runner-gpu"

      checkout_reference:
        description: "The branch, tag, or SHA to checkout"
        required: true
        default: "main"

      debug:
        description: 'Run debugging steps?'
        required: false
        default: "true"

      sleep_time:
        description: '[DEBUG] sleep time for debugging'
        required: true
        default: "0"

      provider_id:
        description: 'ID of your provider'
        required: true
        default: "meta_reference"

      model_id:
        description: 'Shorthand name for target model ID (llama_3b or llama_8b)'
        required: true
        default: "llama_3b"

      model_override_3b:
        description: 'Specify shorthand model for <llama_3b> '
        required: false
        default: "Llama3.2-3B-Instruct"

      model_override_8b:
        description: 'Specify shorthand model for <llama_8b> '
        required: false
        default: "Llama3.1-8B-Instruct"

env:
  # ID used for each test's provider config
  PROVIDER_ID: "${{ inputs.provider_id || 'meta_reference' }}"

  # Path to model checkpoints within EFS volume
  MODEL_CHECKPOINT_DIR: "/data/llama"

  # Path to directory to run tests from
  TESTS_PATH: "${{ github.workspace }}/llama_stack/providers/tests"

  # Keep track of a list of model IDs that are valid to use within pytest fixture marks
  AVAILABLE_MODEL_IDs: "llama_3b llama_8b"

  # Shorthand name for model ID, used in pytest fixture marks
  MODEL_ID: "${{ inputs.model_id || 'llama_3b' }}"

  # Override the `llama_3b` / `llama_8b' models, else use the default.
  LLAMA_3B_OVERRIDE: "${{ inputs.model_override_3b || 'Llama3.2-3B-Instruct' }}"
  LLAMA_8B_OVERRIDE: "${{ inputs.model_override_8b || 'Llama3.1-8B-Instruct' }}"

  # Defines which directories in TESTS_PATH to exclude from the test loop
  EXCLUDED_DIRS: "__pycache__"

  # Defines the output xml reports generated after a test is run
  REPORTS_GEN: ""

jobs:
  execute_workflow:
    name: Execute workload on Self-Hosted GPU k8s runner
    permissions:
      pull-requests: write
    defaults:
      run:
        shell: bash
    runs-on: ${{ inputs.runner != '' && inputs.runner || 'llama-stack-gha-runner-gpu' }}
    if: always()
    steps:

      ##############################
      #### INITIAL DEBUG CHECKS ####
      ##############################
      - name: "[DEBUG] Check content of the EFS mount"
        id: debug_efs_volume
        continue-on-error: true
        if: inputs.debug == 'true'
        run: |
          echo "========= Content of the EFS mount ============="
          ls -la ${{ env.MODEL_CHECKPOINT_DIR }}

      - name: "[DEBUG] Get runner container OS information"
        id: debug_os_info
        if: ${{ inputs.debug == 'true' }}
        run: |
          cat /etc/os-release

      - name: "[DEBUG] Print environment variables"
        id: debug_env_vars
        if: ${{ inputs.debug == 'true' }}
        run: |
          echo "PROVIDER_ID = ${PROVIDER_ID}"
          echo "MODEL_CHECKPOINT_DIR = ${MODEL_CHECKPOINT_DIR}"
          echo "AVAILABLE_MODEL_IDs = ${AVAILABLE_MODEL_IDs}"
          echo "MODEL_ID = ${MODEL_ID}"
          echo "LLAMA_3B_OVERRIDE = ${LLAMA_3B_OVERRIDE}"
          echo "LLAMA_8B_OVERRIDE = ${LLAMA_8B_OVERRIDE}"
          echo "EXCLUDED_DIRS = ${EXCLUDED_DIRS}"
          echo "REPORTS_GEN = ${REPORTS_GEN}"

      ############################
      #### MODEL INPUT CHECKS ####
      ############################

      - name: "Check if env.model_id is valid"
        id: check_model_id
        run: |
          if [[ " ${AVAILABLE_MODEL_IDs[@]} " =~ " ${MODEL_ID} " ]]; then
            echo "Model ID '${MODEL_ID}' is valid."
          else
            echo "Model ID '${MODEL_ID}' is invalid. Terminating workflow."
            exit 1
          fi

      #######################
      #### CODE CHECKOUT ####
      #######################
      - name: "Checkout 'meta-llama/llama-stack' repository"
        id: checkout_repo
        uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
        with:
          ref: ${{ inputs.branch }}

      - name: "[DEBUG] Content of the repository after checkout"
        id: debug_content_after_checkout
        if: ${{ inputs.debug == 'true' }}
        run: |
          ls -la ${GITHUB_WORKSPACE}

      ##########################################################
      ####             OPTIONAL SLEEP DEBUG                 ####
      #                                                        #
      # Use to "exec" into the test k8s POD and run tests      #
      # manually to identify what dependencies are being used. #
      #                                                        #
      ##########################################################
      - name: "[DEBUG] sleep"
        id: debug_sleep
        if: ${{ inputs.debug == 'true' && inputs.sleep_time != '' }}
        run: |
          sleep ${{ inputs.sleep_time }}

      ############################
      #### UPDATE SYSTEM PATH ####
      ############################
      - name: "Update path: execute"
        id: path_update_exec
        run: |
          # .local/bin is needed for certain libraries installed below to be recognized
          # when calling their executable to install sub-dependencies
          mkdir -p ${HOME}/.local/bin
          echo "${HOME}/.local/bin" >> "$GITHUB_PATH"

      #####################################
      #### UPDATE CHECKPOINT DIRECTORY ####
      #####################################
      - name: "Update checkpoint directory"
        id: checkpoint_update
        run: |
          echo "Checkpoint directory: ${MODEL_CHECKPOINT_DIR}/$LLAMA_3B_OVERRIDE"
          if [ "${MODEL_ID}" = "llama_3b" ] && [ -d "${MODEL_CHECKPOINT_DIR}/${LLAMA_3B_OVERRIDE}" ]; then
            echo "MODEL_CHECKPOINT_DIR=${MODEL_CHECKPOINT_DIR}/${LLAMA_3B_OVERRIDE}" >> "$GITHUB_ENV"
          elif [ "${MODEL_ID}" = "llama_8b" ] && [ -d "${MODEL_CHECKPOINT_DIR}/${LLAMA_8B_OVERRIDE}" ]; then
            echo "MODEL_CHECKPOINT_DIR=${MODEL_CHECKPOINT_DIR}/${LLAMA_8B_OVERRIDE}" >> "$GITHUB_ENV"
          else
            echo "MODEL_ID & LLAMA_*B_OVERRIDE are not a valid pairing. Terminating workflow."
            exit 1
          fi

      - name: "[DEBUG] Checkpoint update check"
        id: debug_checkpoint_update
        if: ${{ inputs.debug == 'true' }}
        run: |
          echo "MODEL_CHECKPOINT_DIR (after update) = ${MODEL_CHECKPOINT_DIR}"

      ##################################
      #### DEPENDENCY INSTALLATIONS ####
      ##################################
      - name: "Installing 'apt' required packages"
        id: install_apt
        run: |
          echo "[STEP] Installing 'apt' required packages"
          sudo apt update -y
          sudo apt install -y python3 python3-pip npm wget

      - name: "Installing packages with 'curl'"
        id: install_curl
        run: |
          curl -fsSL https://ollama.com/install.sh | sh

      - name: "Installing packages with 'wget'"
        id: install_wget
        run: |
          wget https://repo.continuum.io/miniconda/Miniconda3-latest-Linux-x86_64.sh
          chmod +x Miniconda3-latest-Linux-x86_64.sh
          ./Miniconda3-latest-Linux-x86_64.sh -b install -c pytorch -c nvidia faiss-gpu=1.9.0
          # Add miniconda3 bin to system path
          echo "${HOME}/miniconda3/bin" >> "$GITHUB_PATH"

      - name: "Installing packages with 'npm'"
        id: install_npm_generic
        run: |
          sudo npm install -g junit-merge

      - name: "Installing pip dependencies"
        id: install_pip_generic
        run: |
          echo "[STEP] Installing 'llama-stack' models"
          pip install -U pip setuptools
          pip install -r requirements.txt
          pip install -e .
          pip install -U \
            torch torchvision \
            pytest pytest_asyncio \
            fairscale lm-format-enforcer \
            zmq chardet pypdf \
            pandas sentence_transformers together \
            aiosqlite
      - name: "Installing packages with conda"
        id: install_conda_generic
        run: |
          conda install -q -c pytorch -c nvidia faiss-gpu=1.9.0

      #############################################################
      #### TESTING TO BE DONE FOR BOTH PRS AND MANUAL DISPATCH ####
      #############################################################
      - name: "Run Tests: Loop"
        id: run_tests_loop
        working-directory: "${{ github.workspace }}"
        run: |
          pattern=""
          for dir in llama_stack/providers/tests/*; do
            if [ -d "$dir" ]; then
              dir_name=$(basename "$dir")
              if [[ ! " $EXCLUDED_DIRS " =~ " $dir_name " ]]; then
                for file in "$dir"/test_*.py; do
                  test_name=$(basename "$file")
                  new_file="result-${dir_name}-${test_name}.xml"
                  if torchrun $(which pytest) -s -v ${TESTS_PATH}/${dir_name}/${test_name} -m "${PROVIDER_ID} and ${MODEL_ID}" \
                    --junitxml="${{ github.workspace }}/${new_file}"; then
                    echo "Ran test: ${test_name}"
                  else
                    echo "Did NOT run test: ${test_name}"
                  fi
                  pattern+="${new_file} "
                done
              fi
            fi
          done
          echo "REPORTS_GEN=$pattern" >> "$GITHUB_ENV"

      - name: "Test Summary: Merge"
        id: test_summary_merge
        working-directory: "${{ github.workspace }}"
        run: |
          echo "Merging the following test result files: ${REPORTS_GEN}"
          # Defaults to merging them into 'merged-test-results.xml'
          junit-merge ${{ env.REPORTS_GEN }}

      ############################################
      #### AUTOMATIC TESTING ON PULL REQUESTS ####
      ############################################

      #### Run tests ####

      - name: "PR - Run Tests"
        id: pr_run_tests
        working-directory: "${{ github.workspace }}"
        if: github.event_name == 'pull_request_target'
        run: |
          echo "[STEP] Running PyTest tests at 'GITHUB_WORKSPACE' path: ${GITHUB_WORKSPACE} | path: ${{ github.workspace }}"
          # (Optional) Add more tests here.

          # Merge test results with 'merged-test-results.xml' from above.
          # junit-merge <new-test-results> merged-test-results.xml

      #### Create test summary ####

      - name: "PR - Test Summary"
        id: pr_test_summary_create
        if: github.event_name == 'pull_request_target'
        uses: test-summary/action@31493c76ec9e7aa675f1585d3ed6f1da69269a86 # v2.4
        with:
          paths: "${{ github.workspace }}/merged-test-results.xml"
          output: test-summary.md

      - name: "PR - Upload Test Summary"
        id: pr_test_summary_upload
        if: github.event_name == 'pull_request_target'
        uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # v4.6.2
        with:
          name: test-summary
          path: test-summary.md

      #### Update PR request ####

      - name: "PR - Update comment"
        id: pr_update_comment
        if: github.event_name == 'pull_request_target'
        uses: thollander/actions-comment-pull-request@24bffb9b452ba05a4f3f77933840a6a841d1b32b # v3.0.1
        with:
          filePath: test-summary.md

      ########################
      #### MANUAL TESTING ####
      ########################

      #### Run tests ####

      - name: "Manual - Run Tests: Prep"
        id: manual_run_tests
        working-directory: "${{ github.workspace }}"
        if: github.event_name == 'workflow_dispatch'
        run: |
          echo "[STEP] Running PyTest tests at 'GITHUB_WORKSPACE' path: ${{ github.workspace }}"

          #TODO Use this when collection errors are resolved
          # pytest -s -v -m "${PROVIDER_ID} and ${MODEL_ID}" --junitxml="${{ github.workspace }}/merged-test-results.xml"

          # (Optional) Add more tests here.

          # Merge test results with 'merged-test-results.xml' from above.
          # junit-merge <new-test-results> merged-test-results.xml

      #### Create test summary ####

      - name: "Manual - Test Summary"
        id: manual_test_summary
        if: always() && github.event_name == 'workflow_dispatch'
        uses: test-summary/action@31493c76ec9e7aa675f1585d3ed6f1da69269a86 # v2.4
        with:
          paths: "${{ github.workspace }}/merged-test-results.xml"
.github/workflows/integration-tests.yml (9 changes, vendored)
@@ -7,11 +7,12 @@ on:
     branches: [ main ]
     paths:
       - 'llama_stack/**'
-      - 'tests/integration/**'
+      - 'tests/**'
       - 'uv.lock'
       - 'pyproject.toml'
       - 'requirements.txt'
       - '.github/workflows/integration-tests.yml' # This workflow
+      - '.github/actions/setup-ollama/action.yml'
   schedule:
     - cron: '0 0 * * *'  # Daily at 12 AM UTC
   workflow_dispatch:

@@ -70,7 +71,7 @@ jobs:
       - name: Build Llama Stack
         run: |
-          uv run llama stack build --template starter --image-type venv
+          uv run llama stack build --template ci-tests --image-type venv

       - name: Check Storage and Memory Available Before Tests
         if: ${{ always() }}

@@ -91,9 +92,9 @@ jobs:
         shell: bash
         run: |
           if [ "${{ matrix.client-type }}" == "library" ]; then
-            stack_config="starter"
+            stack_config="ci-tests"
           else
-            stack_config="server:starter"
+            stack_config="server:ci-tests"
           fi
           uv run pytest -s -v tests/integration/${{ matrix.test-type }} --stack-config=${stack_config} \
             -k "not(builtin_tool or safety_with_image or code_interpreter or test_rag)" \

@@ -93,7 +93,7 @@ jobs:
       - name: Build Llama Stack
         run: |
-          uv run llama stack build --template starter --image-type venv
+          uv run llama stack build --template ci-tests --image-type venv

       - name: Check Storage and Memory Available Before Tests
         if: ${{ always() }}
.github/workflows/providers-build.yml (10 changes, vendored)
@@ -97,9 +97,9 @@ jobs:
       - name: Build a single provider
         run: |
-          yq -i '.image_type = "container"' llama_stack/templates/starter/build.yaml
-          yq -i '.image_name = "test"' llama_stack/templates/starter/build.yaml
-          USE_COPY_NOT_MOUNT=true LLAMA_STACK_DIR=. uv run llama stack build --config llama_stack/templates/starter/build.yaml
+          yq -i '.image_type = "container"' llama_stack/templates/ci-tests/build.yaml
+          yq -i '.image_name = "test"' llama_stack/templates/ci-tests/build.yaml
+          USE_COPY_NOT_MOUNT=true LLAMA_STACK_DIR=. uv run llama stack build --config llama_stack/templates/ci-tests/build.yaml

       - name: Inspect the container image entrypoint
         run: |

@@ -126,14 +126,14 @@ jobs:
             .image_type = "container" |
             .image_name = "ubi9-test" |
             .distribution_spec.container_image = "registry.access.redhat.com/ubi9:latest"
-          ' llama_stack/templates/starter/build.yaml
+          ' llama_stack/templates/ci-tests/build.yaml

       - name: Build dev container (UBI9)
         env:
           USE_COPY_NOT_MOUNT: "true"
           LLAMA_STACK_DIR: "."
         run: |
-          uv run llama stack build --config llama_stack/templates/starter/build.yaml
+          uv run llama stack build --config llama_stack/templates/ci-tests/build.yaml

       - name: Inspect UBI9 image
         run: |
.github/workflows/python-build-test.yml (2 changes, vendored)
@@ -20,7 +20,7 @@ jobs:
         uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2

       - name: Install uv
-        uses: astral-sh/setup-uv@bd01e18f51369d5a26f1651c3cb451d3417e3bba # v6.3.1
+        uses: astral-sh/setup-uv@7edac99f961f18b581bbd960d59d049f04c0002f # v6.4.1
         with:
           python-version: ${{ matrix.python-version }}
           activate-environment: true

.github/workflows/unit-tests.yml (2 changes, vendored)
@@ -36,7 +36,7 @@ jobs:
       - name: Run unit tests
         run: |
-          PYTHON_VERSION=${{ matrix.python }} ./scripts/unit-tests.sh --cov=llama_stack --junitxml=pytest-report-${{ matrix.python }}.xml --cov-report=html:htmlcov-${{ matrix.python }}
+          PYTHON_VERSION=${{ matrix.python }} ./scripts/unit-tests.sh --junitxml=pytest-report-${{ matrix.python }}.xml

       - name: Upload test results
         if: always()
@@ -129,6 +129,22 @@ repos:
         require_serial: true
         always_run: true
         files: ^llama_stack/.*$
+      - id: forbid-pytest-asyncio
+        name: Block @pytest.mark.asyncio and @pytest_asyncio.fixture
+        entry: bash
+        language: system
+        types: [python]
+        pass_filenames: true
+        args:
+          - -c
+          - |
+            grep -EnH '^[^#]*@pytest\.mark\.asyncio|@pytest_asyncio\.fixture' "$@" && {
+              echo;
+              echo "❌ Do not use @pytest.mark.asyncio or @pytest_asyncio.fixture."
+              echo "   pytest is already configured with async-mode=auto."
+              echo;
+              exit 1;
+            } || true

 ci:
   autofix_commit_msg: 🎨 [pre-commit.ci] Auto format from pre-commit.com hooks
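Note: as the hook's error message above states, pytest is already configured with async-mode=auto, so async tests need no decorator at all. A minimal sketch of the style the new hook enforces (test and fixture names here are illustrative, not from the repo):

```python
# Illustrative unit test: no @pytest.mark.asyncio or @pytest_asyncio.fixture decorators.
# With pytest-asyncio in auto mode, plain `async def` tests are collected and awaited.
import pytest


@pytest.fixture
def sample_payload() -> dict:
    # Ordinary synchronous fixtures keep working unchanged.
    return {"prompt": "hello"}


async def test_async_path_runs_without_decorator(sample_payload: dict) -> None:
    # Awaited automatically by pytest-asyncio's auto mode.
    assert sample_payload["prompt"] == "hello"
```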
@@ -112,7 +112,7 @@ uv run pre-commit run --all-files
 ## Running tests

-You can find the Llama Stack testing documentation here [here](tests/README.md).
+You can find the Llama Stack testing documentation [here](https://github.com/meta-llama/llama-stack/blob/main/tests/README.md).

 ## Adding a new dependency to the project

@@ -6,6 +6,7 @@
 [](https://discord.gg/llama-stack)
 [](https://github.com/meta-llama/llama-stack/actions/workflows/unit-tests.yml?query=branch%3Amain)
 [](https://github.com/meta-llama/llama-stack/actions/workflows/integration-tests.yml?query=branch%3Amain)
+![coverage badge](coverage.svg)

 [**Quick Start**](https://llama-stack.readthedocs.io/en/latest/getting_started/index.html) | [**Documentation**](https://llama-stack.readthedocs.io/en/latest/index.html) | [**Colab Notebook**](./docs/getting_started.ipynb) | [**Discord**](https://discord.gg/llama-stack)
coverage.svg (new file, 21 lines)
@@ -0,0 +1,21 @@
<?xml version="1.0" encoding="UTF-8"?>
<svg xmlns="http://www.w3.org/2000/svg" width="99" height="20">
    <linearGradient id="b" x2="0" y2="100%">
        <stop offset="0" stop-color="#bbb" stop-opacity=".1"/>
        <stop offset="1" stop-opacity=".1"/>
    </linearGradient>
    <mask id="a">
        <rect width="99" height="20" rx="3" fill="#fff"/>
    </mask>
    <g mask="url(#a)">
        <path fill="#555" d="M0 0h63v20H0z"/>
        <path fill="#fe7d37" d="M63 0h36v20H63z"/>
        <path fill="url(#b)" d="M0 0h99v20H0z"/>
    </g>
    <g fill="#fff" text-anchor="middle" font-family="DejaVu Sans,Verdana,Geneva,sans-serif" font-size="11">
        <text x="31.5" y="15" fill="#010101" fill-opacity=".3">coverage</text>
        <text x="31.5" y="14">coverage</text>
        <text x="80" y="15" fill="#010101" fill-opacity=".3">44%</text>
        <text x="80" y="14">44%</text>
    </g>
</svg>
docs/_static/llama-stack-spec.html (29 changes, vendored)
@@ -14470,28 +14470,31 @@
       "DPOAlignmentConfig": {
         "type": "object",
         "properties": {
-          "reward_scale": {
+          "beta": {
             "type": "number"
           },
-          "reward_clip": {
-            "type": "number"
-          },
-          "epsilon": {
-            "type": "number"
-          },
-          "gamma": {
-            "type": "number"
+          "loss_type": {
+            "$ref": "#/components/schemas/DPOLossType",
+            "default": "sigmoid"
           }
         },
         "additionalProperties": false,
         "required": [
-          "reward_scale",
-          "reward_clip",
-          "epsilon",
-          "gamma"
+          "beta",
+          "loss_type"
         ],
         "title": "DPOAlignmentConfig"
       },
+      "DPOLossType": {
+        "type": "string",
+        "enum": [
+          "sigmoid",
+          "hinge",
+          "ipo",
+          "kto_pair"
+        ],
+        "title": "DPOLossType"
+      },
       "DataConfig": {
         "type": "object",
         "properties": {
docs/_static/llama-stack-spec.yaml (25 changes, vendored)
@@ -10111,21 +10111,24 @@ components:
   DPOAlignmentConfig:
     type: object
     properties:
-      reward_scale:
-        type: number
-      reward_clip:
-        type: number
-      epsilon:
-        type: number
-      gamma:
+      beta:
         type: number
+      loss_type:
+        $ref: '#/components/schemas/DPOLossType'
+        default: sigmoid
     additionalProperties: false
     required:
-      - reward_scale
-      - reward_clip
-      - epsilon
-      - gamma
+      - beta
+      - loss_type
     title: DPOAlignmentConfig
+  DPOLossType:
+    type: string
+    enum:
+      - sigmoid
+      - hinge
+      - ipo
+      - kto_pair
+    title: DPOLossType
   DataConfig:
     type: object
     properties:
@@ -13,7 +13,7 @@ Llama Stack allows you to build different layers of distributions for your AI wo
 Building production AI applications today requires solving multiple challenges:

-Infrastructure Complexity
+**Infrastructure Complexity**

 - Running large language models efficiently requires specialized infrastructure.
 - Different deployment scenarios (local development, cloud, edge) need different solutions.

@@ -222,10 +222,21 @@ llama-stack-client --endpoint http://localhost:5000 inference chat-completion --
 ## Deploying Llama Stack Server in AWS EKS

-We've also provided a script to deploy the Llama Stack server in an AWS EKS cluster. Once you have an [EKS cluster](https://docs.aws.amazon.com/eks/latest/userguide/getting-started.html), you can run the following script to deploy the Llama Stack server.
+We've also provided a script to deploy the Llama Stack server in an AWS EKS cluster.
+
+Prerequisites:
+- Set up an [EKS cluster](https://docs.aws.amazon.com/eks/latest/userguide/getting-started.html).
+- Create a [Github OAuth app](https://docs.github.com/en/apps/oauth-apps/building-oauth-apps/creating-an-oauth-app) and get the client ID and client secret.
+  - Set the `Authorization callback URL` to `http://<your-llama-stack-ui-url>/api/auth/callback/`
+
+Run the following script to deploy the Llama Stack server:

 ```
 export HF_TOKEN=<your-huggingface-token>
+export GITHUB_CLIENT_ID=<your-github-client-id>
+export GITHUB_CLIENT_SECRET=<your-github-client-secret>
+export LLAMA_STACK_UI_URL=<your-llama-stack-ui-url>

 cd docs/source/distributions/eks
 ./apply.sh
 ```
@@ -21,6 +21,24 @@ else
   exit 1
 fi

+if [ -z "${GITHUB_CLIENT_ID:-}" ]; then
+  echo "ERROR: GITHUB_CLIENT_ID not set. You need it for Github login to work. Refer to https://llama-stack.readthedocs.io/en/latest/deploying/index.html#kubernetes-deployment-guide"
+  exit 1
+fi
+
+if [ -z "${GITHUB_CLIENT_SECRET:-}" ]; then
+  echo "ERROR: GITHUB_CLIENT_SECRET not set. You need it for Github login to work. Refer to https://llama-stack.readthedocs.io/en/latest/deploying/index.html#kubernetes-deployment-guide"
+  exit 1
+fi
+
+if [ -z "${LLAMA_STACK_UI_URL:-}" ]; then
+  echo "ERROR: LLAMA_STACK_UI_URL not set. Should be set to the external URL of the UI (excluding port). You need it for Github login to work. Refer to https://llama-stack.readthedocs.io/en/latest/deploying/index.html#kubernetes-deployment-guide"
+  exit 1
+fi
+
 set -euo pipefail
 set -x

@@ -122,6 +122,9 @@ data:
         provider_id: rag-runtime
   server:
     port: 8321
+    auth:
+      provider_config:
+        type: github_token
 kind: ConfigMap
 metadata:
   creationTimestamp: null

@@ -27,7 +27,7 @@ spec:
     spec:
       containers:
       - name: llama-stack
-        image: llamastack/distribution-remote-vllm:latest
+        image: llamastack/distribution-starter:latest
        imagePullPolicy: Always # since we have specified latest instead of a version
         env:
         - name: ENABLE_CHROMADB

@@ -119,3 +119,6 @@ tool_groups:
   provider_id: rag-runtime
 server:
   port: 8321
+  auth:
+    provider_config:
+      type: github_token

@@ -26,6 +26,12 @@ spec:
           value: "http://llama-stack-service:8321"
         - name: LLAMA_STACK_UI_PORT
           value: "8322"
+        - name: GITHUB_CLIENT_ID
+          value: "${GITHUB_CLIENT_ID}"
+        - name: GITHUB_CLIENT_SECRET
+          value: "${GITHUB_CLIENT_SECRET}"
+        - name: NEXTAUTH_URL
+          value: "${LLAMA_STACK_UI_URL}:8322"
         args:
           - -c
           - |
@@ -167,7 +167,7 @@ When using the `:` pattern (like `${env.OLLAMA_INFERENCE_MODEL:__disabled__}`),
 ## Running the Distribution

-You can run the starter distribution via Docker or Conda.
+You can run the starter distribution via Docker, Conda, or venv.

 ### Via Docker

@@ -186,17 +186,12 @@ docker run \
   --port $LLAMA_STACK_PORT
 ```

-### Via Conda
+### Via Conda or venv

-Make sure you have done `uv pip install llama-stack` and have the Llama Stack CLI available.
+Ensure you have configured the starter distribution using the environment variables explained above.

 ```bash
-llama stack build --template starter --image-type conda
-llama stack run distributions/starter/run.yaml \
-  --port 8321 \
-  --env OPENAI_API_KEY=your_openai_key \
-  --env FIREWORKS_API_KEY=your_fireworks_key \
-  --env TOGETHER_API_KEY=your_together_key
+uv run --with llama-stack llama stack build --template starter --image-type <conda|venv> --run
 ```

 ## Example Usage

@@ -19,7 +19,7 @@ ollama run llama3.2:3b --keepalive 60m
 #### Step 2: Run the Llama Stack server
 We will use `uv` to run the Llama Stack server.
 ```bash
-INFERENCE_MODEL=llama3.2:3b uv run --with llama-stack llama stack build --template starter --image-type venv --run
+ENABLE_OLLAMA=ollama OLLAMA_INFERENCE_MODEL=llama3.2:3b uv run --with llama-stack llama stack build --template starter --image-type venv --run
 ```
 #### Step 3: Run the demo
 Now open up a new terminal and copy the following script into a file named `demo_script.py`.

@@ -111,6 +111,12 @@ Ultimately, great work is about making a meaningful contribution and leaving a l
 ```
 Congratulations! You've successfully built your first RAG application using Llama Stack! 🎉🥳

+```{admonition} HuggingFace access
+:class: tip
+
+If you are getting a **401 Client Error** from HuggingFace for the **all-MiniLM-L6-v2** model, try setting **HF_TOKEN** to a valid HuggingFace token in your environment
+```
+
 ### Next Steps

 Now you're ready to dive deeper into Llama Stack!
@@ -4,7 +4,6 @@ This section contains documentation for all available providers for the **infere
 - [inline::meta-reference](inline_meta-reference.md)
 - [inline::sentence-transformers](inline_sentence-transformers.md)
-- [inline::vllm](inline_vllm.md)
 - [remote::anthropic](remote_anthropic.md)
 - [remote::bedrock](remote_bedrock.md)
 - [remote::cerebras](remote_cerebras.md)

(deleted file, 29 lines)
@@ -1,29 +0,0 @@
# inline::vllm

## Description

vLLM inference provider for high-performance model serving with PagedAttention and continuous batching.

## Configuration

| Field | Type | Required | Default | Description |
|-------|------|----------|---------|-------------|
| `tensor_parallel_size` | `<class 'int'>` | No | 1 | Number of tensor parallel replicas (number of GPUs to use). |
| `max_tokens` | `<class 'int'>` | No | 4096 | Maximum number of tokens to generate. |
| `max_model_len` | `<class 'int'>` | No | 4096 | Maximum context length to use during serving. |
| `max_num_seqs` | `<class 'int'>` | No | 4 | Maximum parallel batch size for generation. |
| `enforce_eager` | `<class 'bool'>` | No | False | Whether to use eager mode for inference (otherwise cuda graphs are used). |
| `gpu_memory_utilization` | `<class 'float'>` | No | 0.3 | How much GPU memory will be allocated when this provider has finished loading, including memory that was already allocated before loading. |

## Sample Configuration

```yaml
tensor_parallel_size: ${env.TENSOR_PARALLEL_SIZE:=1}
max_tokens: ${env.MAX_TOKENS:=4096}
max_model_len: ${env.MAX_MODEL_LEN:=4096}
max_num_seqs: ${env.MAX_NUM_SEQS:=4}
enforce_eager: ${env.ENFORCE_EAGER:=False}
gpu_memory_utilization: ${env.GPU_MEMORY_UTILIZATION:=0.3}

```

@@ -9,6 +9,8 @@ Ollama inference provider for running local models through the Ollama runtime.
 | Field | Type | Required | Default | Description |
 |-------|------|----------|---------|-------------|
 | `url` | `<class 'str'>` | No | http://localhost:11434 | |
+| `refresh_models` | `<class 'bool'>` | No | False | refresh and re-register models periodically |
+| `refresh_models_interval` | `<class 'int'>` | No | 300 | interval in seconds to refresh models |

 ## Sample Configuration

@@ -12,11 +12,13 @@ Remote vLLM inference provider for connecting to vLLM servers.
 | `max_tokens` | `<class 'int'>` | No | 4096 | Maximum number of tokens to generate. |
 | `api_token` | `str \| None` | No | fake | The API token |
 | `tls_verify` | `bool \| str` | No | True | Whether to verify TLS certificates. Can be a boolean or a path to a CA certificate file. |
+| `refresh_models` | `<class 'bool'>` | No | False | Whether to refresh models periodically |
+| `refresh_models_interval` | `<class 'int'>` | No | 300 | Interval in seconds to refresh models |

 ## Sample Configuration

 ```yaml
-url: ${env.VLLM_URL}
+url: ${env.VLLM_URL:=}
 max_tokens: ${env.VLLM_MAX_TOKENS:=4096}
 api_token: ${env.VLLM_API_TOKEN:=fake}
 tls_verify: ${env.VLLM_TLS_VERIFY:=true}
@@ -819,6 +819,12 @@ class OpenAIEmbeddingsResponse(BaseModel):
 class ModelStore(Protocol):
     async def get_model(self, identifier: str) -> Model: ...

+    async def update_registered_llm_models(
+        self,
+        provider_id: str,
+        models: list[Model],
+    ) -> None: ...
+

 class TextTruncation(Enum):
     """Config for how to truncate text for embedding when text is longer than the model's max sequence length. Start and End semantics depend on whether the language is left-to-right or right-to-left.

@@ -7,7 +7,7 @@
 from enum import StrEnum
 from typing import Any, Literal, Protocol, runtime_checkable

-from pydantic import BaseModel, ConfigDict, Field
+from pydantic import BaseModel, ConfigDict, Field, field_validator

 from llama_stack.apis.resource import Resource, ResourceType
 from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol

@@ -36,13 +36,21 @@ class Model(CommonModelFields, Resource):
         return self.identifier

     @property
-    def provider_model_id(self) -> str | None:
+    def provider_model_id(self) -> str:
+        assert self.provider_resource_id is not None, "Provider resource ID must be set"
         return self.provider_resource_id

     model_config = ConfigDict(protected_namespaces=())

     model_type: ModelType = Field(default=ModelType.llm)

+    @field_validator("provider_resource_id")
+    @classmethod
+    def validate_provider_resource_id(cls, v):
+        if v is None:
+            raise ValueError("provider_resource_id cannot be None")
+        return v
+

 class ModelInput(CommonModelFields):
     model_id: str
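Note: the net effect of the validator above is that a missing provider resource ID is rejected at construction time rather than surfacing later as a `None` provider model id. A standalone sketch of the pattern (it mirrors the shape of `Model` with a made-up class rather than importing it, since the real class has other fields not shown in this hunk):

```python
# Self-contained illustration of the field_validator pattern added above.
from pydantic import BaseModel, ValidationError, field_validator


class ResourceSketch(BaseModel):
    identifier: str
    provider_resource_id: str | None = None

    @field_validator("provider_resource_id")
    @classmethod
    def validate_provider_resource_id(cls, v: str | None) -> str:
        if v is None:
            raise ValueError("provider_resource_id cannot be None")
        return v


try:
    ResourceSketch(identifier="m1", provider_resource_id=None)
except ValidationError as exc:
    # pydantic wraps the ValueError raised inside the validator
    print(exc.errors()[0]["msg"])
```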
@@ -104,12 +104,18 @@ class RLHFAlgorithm(Enum):
     dpo = "dpo"


+@json_schema_type
+class DPOLossType(Enum):
+    sigmoid = "sigmoid"
+    hinge = "hinge"
+    ipo = "ipo"
+    kto_pair = "kto_pair"
+
+
 @json_schema_type
 class DPOAlignmentConfig(BaseModel):
-    reward_scale: float
-    reward_clip: float
-    epsilon: float
-    gamma: float
+    beta: float
+    loss_type: DPOLossType = DPOLossType.sigmoid


 @json_schema_type
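Note: a minimal usage sketch of the reshaped DPO config above. The enum and class bodies are copied from the hunk, but they are re-declared locally here because the module path of the real classes is not shown in this diff; values are illustrative.

```python
# Local mirror of the new DPO config shape; not an import of the llama_stack module.
from enum import Enum

from pydantic import BaseModel


class DPOLossType(Enum):
    sigmoid = "sigmoid"
    hinge = "hinge"
    ipo = "ipo"
    kto_pair = "kto_pair"


class DPOAlignmentConfig(BaseModel):
    beta: float
    loss_type: DPOLossType = DPOLossType.sigmoid


# beta is now the only required field; loss_type falls back to sigmoid when omitted.
default_cfg = DPOAlignmentConfig(beta=0.1)
ipo_cfg = DPOAlignmentConfig(beta=0.1, loss_type=DPOLossType.ipo)
print(default_cfg.loss_type, ipo_cfg.loss_type)
```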
@@ -47,8 +47,7 @@ class StackRun(Subcommand):
         self.parser.add_argument(
             "--image-name",
             type=str,
-            default=os.environ.get("CONDA_DEFAULT_ENV"),
-            help="Name of the image to run. Defaults to the current environment",
+            help="Name of the image to run.",
         )
         self.parser.add_argument(
             "--env",

@@ -12,11 +12,13 @@ import os
 import sys
 from concurrent.futures import ThreadPoolExecutor
 from enum import Enum
+from io import BytesIO
 from pathlib import Path
 from typing import Any, TypeVar, Union, get_args, get_origin

 import httpx
 import yaml
+from fastapi import Response as FastAPIResponse
 from llama_stack_client import (
     NOT_GIVEN,
     APIResponse,
@@ -112,6 +114,27 @@ def convert_to_pydantic(annotation: Any, value: Any) -> Any:
         raise ValueError(f"Failed to convert parameter {value} into {annotation}: {e}") from e


+class LibraryClientUploadFile:
+    """LibraryClient UploadFile object that mimics FastAPI's UploadFile interface."""
+
+    def __init__(self, filename: str, content: bytes):
+        self.filename = filename
+        self.content = content
+        self.content_type = "application/octet-stream"
+
+    async def read(self) -> bytes:
+        return self.content
+
+
+class LibraryClientHttpxResponse:
+    """LibraryClient httpx Response object for FastAPI Response conversion."""
+
+    def __init__(self, response):
+        self.content = response.body if isinstance(response.body, bytes) else response.body.encode()
+        self.status_code = response.status_code
+        self.headers = response.headers
+
+
 class LlamaStackAsLibraryClient(LlamaStackClient):
     def __init__(
         self,

@@ -128,6 +151,8 @@ class LlamaStackAsLibraryClient(LlamaStackClient):
         self.skip_logger_removal = skip_logger_removal
         self.provider_data = provider_data

+        self.loop = asyncio.new_event_loop()
+
     def initialize(self):
         if in_notebook():
             import nest_asyncio

@@ -136,7 +161,7 @@ class LlamaStackAsLibraryClient(LlamaStackClient):
         if not self.skip_logger_removal:
             self._remove_root_logger_handlers()

-        return asyncio.run(self.async_client.initialize())
+        return self.loop.run_until_complete(self.async_client.initialize())

     def _remove_root_logger_handlers(self):
         """

@@ -149,10 +174,7 @@ class LlamaStackAsLibraryClient(LlamaStackClient):
             logger.info(f"Removed handler {handler.__class__.__name__} from root logger")

     def request(self, *args, **kwargs):
-        # NOTE: We are using AsyncLlamaStackClient under the hood
-        # A new event loop is needed to convert the AsyncStream
-        # from async client into SyncStream return type for streaming
-        loop = asyncio.new_event_loop()
+        loop = self.loop
         asyncio.set_event_loop(loop)

         if kwargs.get("stream"):

@@ -169,7 +191,6 @@ class LlamaStackAsLibraryClient(LlamaStackClient):
                 pending = asyncio.all_tasks(loop)
                 if pending:
                     loop.run_until_complete(asyncio.gather(*pending, return_exceptions=True))
-                loop.close()

             return sync_generator()
         else:

@@ -179,7 +200,6 @@ class LlamaStackAsLibraryClient(LlamaStackClient):
             pending = asyncio.all_tasks(loop)
             if pending:
                 loop.run_until_complete(asyncio.gather(*pending, return_exceptions=True))
-            loop.close()
             return result
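Note: the hunks above switch the library client from a fresh event loop per request (via `asyncio.run` / `asyncio.new_event_loop` plus `loop.close()`) to a single loop created once in `__init__` and reused for every call. A generic, standalone sketch of that pattern (names below are illustrative and not part of the llama-stack API):

```python
# Generic illustration of the persistent-event-loop pattern adopted above.
import asyncio


class SyncFacade:
    """Stand-in for a sync wrapper around an async client."""

    def __init__(self) -> None:
        # One loop for the lifetime of the object, mirroring self.loop above.
        self.loop = asyncio.new_event_loop()

    def call(self, coro):
        # Reuse the same loop for every request instead of asyncio.run(),
        # so tasks started by one call can still be running for the next.
        asyncio.set_event_loop(self.loop)
        return self.loop.run_until_complete(coro)


async def add(a: int, b: int) -> int:
    await asyncio.sleep(0)
    return a + b


facade = SyncFacade()
print(facade.call(add(1, 2)), facade.call(add(3, 4)))
```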
@@ -295,6 +315,31 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
         )
         return response

+    def _handle_file_uploads(self, options: Any, body: dict) -> tuple[dict, list[str]]:
+        """Handle file uploads from OpenAI client and add them to the request body."""
+        if not (hasattr(options, "files") and options.files):
+            return body, []
+
+        if not isinstance(options.files, list):
+            return body, []
+
+        field_names = []
+        for file_tuple in options.files:
+            if not (isinstance(file_tuple, tuple) and len(file_tuple) >= 2):
+                continue
+
+            field_name = file_tuple[0]
+            file_object = file_tuple[1]
+
+            if isinstance(file_object, BytesIO):
+                file_object.seek(0)
+                file_content = file_object.read()
+                filename = getattr(file_object, "name", "uploaded_file")
+                field_names.append(field_name)
+                body[field_name] = LibraryClientUploadFile(filename, file_content)
+
+        return body, field_names
+
     async def _call_non_streaming(
         self,
         *,

@@ -310,15 +355,23 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
         matched_func, path_params, route = find_matching_route(options.method, path, self.route_impls)
         body |= path_params
-        body = self._convert_body(path, options.method, body)
+
+        body, field_names = self._handle_file_uploads(options, body)
+
+        body = self._convert_body(path, options.method, body, exclude_params=set(field_names))
+
         await start_trace(route, {"__location__": "library_client"})
         try:
             result = await matched_func(**body)
         finally:
             await end_trace()

+        # Handle FastAPI Response objects (e.g., from file content retrieval)
+        if isinstance(result, FastAPIResponse):
+            return LibraryClientHttpxResponse(result)
+
         json_content = json.dumps(convert_pydantic_to_json_value(result))

+        filtered_body = {k: v for k, v in body.items() if not isinstance(v, LibraryClientUploadFile)}
         mock_response = httpx.Response(
             status_code=httpx.codes.OK,
             content=json_content.encode("utf-8"),

@@ -330,7 +383,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
                 url=options.url,
                 params=options.params,
                 headers=options.headers or {},
-                json=convert_pydantic_to_json_value(body),
+                json=convert_pydantic_to_json_value(filtered_body),
             ),
         )
         response = APIResponse(

@@ -405,13 +458,17 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
         )
         return await response.parse()

-    def _convert_body(self, path: str, method: str, body: dict | None = None) -> dict:
+    def _convert_body(
+        self, path: str, method: str, body: dict | None = None, exclude_params: set[str] | None = None
+    ) -> dict:
         if not body:
             return {}

         if self.route_impls is None:
             raise ValueError("Client not initialized")

+        exclude_params = exclude_params or set()
+
         func, _, _ = find_matching_route(method, path, self.route_impls)
         sig = inspect.signature(func)

@@ -423,6 +480,9 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
         for param_name, param in sig.parameters.items():
             if param_name in body:
                 value = body.get(param_name)
-                converted_body[param_name] = convert_to_pydantic(param.annotation, value)
+                if param_name in exclude_params:
+                    converted_body[param_name] = value
+                else:
+                    converted_body[param_name] = convert_to_pydantic(param.annotation, value)

         return converted_body
@@ -200,7 +200,7 @@ def validate_and_prepare_providers(
     specs = {}
     for provider in providers:
         if not provider.provider_id or provider.provider_id == "__disabled__":
-            logger.warning(f"Provider `{provider.provider_type}` for API `{api}` is disabled")
+            logger.debug(f"Provider `{provider.provider_type}` for API `{api}` is disabled")
             continue

         validate_provider(provider, api, provider_registry)

@@ -80,3 +80,38 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
         if existing_model is None:
             raise ValueError(f"Model {model_id} not found")
         await self.unregister_object(existing_model)
+
+    async def update_registered_llm_models(
+        self,
+        provider_id: str,
+        models: list[Model],
+    ) -> None:
+        existing_models = await self.get_all_with_type("model")
+
+        # we may have an alias for the model registered by the user (or during initialization
+        # from run.yaml) that we need to keep track of
+        model_ids = {}
+        for model in existing_models:
+            # we leave embeddings models alone because often we don't get metadata
+            # (embedding dimension, etc.) from the provider
+            if model.provider_id == provider_id and model.model_type == ModelType.llm:
+                model_ids[model.provider_resource_id] = model.identifier
+                logger.debug(f"unregistering model {model.identifier}")
+                await self.unregister_object(model)
+
+        for model in models:
+            if model.model_type != ModelType.llm:
+                continue
+            if model.provider_resource_id in model_ids:
+                model.identifier = model_ids[model.provider_resource_id]
+
+            logger.debug(f"registering model {model.identifier} ({model.provider_resource_id})")
+            await self.register_object(
+                ModelWithOwner(
+                    identifier=model.identifier,
+                    provider_resource_id=model.provider_resource_id,
+                    provider_id=provider_id,
+                    metadata=model.metadata,
+                    model_type=model.model_type,
+                )
+            )
@@ -445,9 +445,7 @@ def main(args: argparse.Namespace | None = None):
     # now that the logger is initialized, print the line about which type of config we are using.
     logger.info(log_line)

-    logger.info("Run configuration:")
-    safe_config = redact_sensitive_fields(config.model_dump(mode="json"))
-    logger.info(yaml.dump(safe_config, indent=2))
+    _log_run_config(run_config=config)

     app = FastAPI(
         lifespan=lifespan,

@@ -455,6 +453,7 @@ def main(args: argparse.Namespace | None = None):
         redoc_url="/redoc",
         openapi_url="/openapi.json",
     )

     if not os.environ.get("LLAMA_STACK_DISABLE_VERSION_CHECK"):
         app.add_middleware(ClientVersionMiddleware)

@@ -493,7 +492,13 @@ def main(args: argparse.Namespace | None = None):
     )

     try:
-        impls = asyncio.run(construct_stack(config))
+        # Create and set the event loop that will be used for both construction and server runtime
+        loop = asyncio.new_event_loop()
+        asyncio.set_event_loop(loop)
+
+        # Construct the stack in the persistent event loop
+        impls = loop.run_until_complete(construct_stack(config))
+
     except InvalidProviderError as e:
         logger.error(f"Error: {str(e)}")
         sys.exit(1)

@@ -591,7 +596,16 @@ def main(args: argparse.Namespace | None = None):
     if ssl_config:
         uvicorn_config.update(ssl_config)

-    uvicorn.run(**uvicorn_config)
+    # Run uvicorn in the existing event loop to preserve background tasks
+    loop.run_until_complete(uvicorn.Server(uvicorn.Config(**uvicorn_config)).serve())
+
+
+def _log_run_config(run_config: StackRunConfig):
+    """Logs the run config with redacted fields and disabled providers removed."""
+    logger.info("Run configuration:")
+    safe_config = redact_sensitive_fields(run_config.model_dump(mode="json"))
+    clean_config = remove_disabled_providers(safe_config)
+    logger.info(yaml.dump(clean_config, indent=2))


 def extract_path_params(route: str) -> list[str]:

@@ -602,5 +616,20 @@ def extract_path_params(route: str) -> list[str]:
     return params


+def remove_disabled_providers(obj):
+    if isinstance(obj, dict):
+        if (
+            obj.get("provider_id") == "__disabled__"
+            or obj.get("shield_id") == "__disabled__"
+            or obj.get("provider_model_id") == "__disabled__"
+        ):
+            return None
+        return {k: v for k, v in ((k, remove_disabled_providers(v)) for k, v in obj.items()) if v is not None}
+    elif isinstance(obj, list):
+        return [item for item in (remove_disabled_providers(i) for i in obj) if item is not None]
+    else:
+        return obj
+
+
 if __name__ == "__main__":
     main()
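Note: a quick illustration of what `remove_disabled_providers` above does to a run config before it is logged. The function body is copied verbatim from the hunk; the config fragment is made up for the example.

```python
# Copied from the hunk above, applied to a hypothetical config fragment.
def remove_disabled_providers(obj):
    if isinstance(obj, dict):
        if (
            obj.get("provider_id") == "__disabled__"
            or obj.get("shield_id") == "__disabled__"
            or obj.get("provider_model_id") == "__disabled__"
        ):
            return None
        return {k: v for k, v in ((k, remove_disabled_providers(v)) for k, v in obj.items()) if v is not None}
    elif isinstance(obj, list):
        return [item for item in (remove_disabled_providers(i) for i in obj) if item is not None]
    else:
        return obj


sample = {
    "providers": {
        "inference": [
            {"provider_id": "ollama"},
            {"provider_id": "__disabled__"},
        ]
    },
    "models": [{"model_id": "m1", "provider_model_id": "__disabled__"}],
}
print(remove_disabled_providers(sample))
# -> {'providers': {'inference': [{'provider_id': 'ollama'}]}, 'models': []}
```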
@@ -172,7 +172,6 @@ def replace_env_vars(config: Any, path: str = "") -> Any:
                     # Create a copy with resolved provider_id but original config
                     disabled_provider = v.copy()
                     disabled_provider["provider_id"] = resolved_provider_id
                     result.append(disabled_provider)
                     continue
                 except EnvVarError:
                     # If we can't resolve the provider_id, continue with normal processing

@@ -8,6 +8,7 @@ import io
 import json
 import uuid
 from dataclasses import dataclass
+from typing import Any

 from PIL import Image as PIL_Image

@@ -184,16 +185,26 @@ class ChatFormat:
             content = content[: -len("<|eom_id|>")]
             stop_reason = StopReason.end_of_message

-        tool_name = None
-        tool_arguments = {}
+        tool_name: str | BuiltinTool | None = None
+        tool_arguments: dict[str, Any] = {}

         custom_tool_info = ToolUtils.maybe_extract_custom_tool_call(content)
         if custom_tool_info is not None:
-            tool_name, tool_arguments = custom_tool_info
+            # Type guard: ensure custom_tool_info is a tuple of correct types
+            if isinstance(custom_tool_info, tuple) and len(custom_tool_info) == 2:
+                extracted_tool_name, extracted_tool_arguments = custom_tool_info
+                # Handle both dict and str return types from the function
+                if isinstance(extracted_tool_arguments, dict):
+                    tool_name, tool_arguments = extracted_tool_name, extracted_tool_arguments
+                else:
+                    # If it's a string, treat it as a query parameter
+                    tool_name, tool_arguments = extracted_tool_name, {"query": extracted_tool_arguments}
+            else:
+                tool_name, tool_arguments = None, {}
             # Sometimes when agent has custom tools alongside builin tools
             # Agent responds for builtin tool calls in the format of the custom tools
             # This code tries to handle that case
-            if tool_name in BuiltinTool.__members__:
+            if tool_name is not None and tool_name in BuiltinTool.__members__:
                 tool_name = BuiltinTool[tool_name]
                 if isinstance(tool_arguments, dict):
                     tool_arguments = {
@ -6,7 +6,7 @@
|
|||
|
||||
from typing import Any
|
||||
|
||||
from llama_stack.distribution.datatypes import Api
|
||||
from llama_stack.distribution.datatypes import AccessRule, Api
|
||||
|
||||
from .config import LocalfsFilesImplConfig
|
||||
from .files import LocalfsFilesImpl
|
||||
|
|
@ -14,7 +14,7 @@ from .files import LocalfsFilesImpl
|
|||
__all__ = ["LocalfsFilesImpl", "LocalfsFilesImplConfig"]
|
||||
|
||||
|
||||
async def get_provider_impl(config: LocalfsFilesImplConfig, deps: dict[Api, Any]):
|
||||
impl = LocalfsFilesImpl(config)
|
||||
async def get_provider_impl(config: LocalfsFilesImplConfig, deps: dict[Api, Any], policy: list[AccessRule]):
|
||||
impl = LocalfsFilesImpl(config, policy)
|
||||
await impl.initialize()
|
||||
return impl
|
||||
|
|
|
|||
|
|
@ -19,16 +19,19 @@ from llama_stack.apis.files import (
|
|||
OpenAIFileObject,
|
||||
OpenAIFilePurpose,
|
||||
)
|
||||
from llama_stack.distribution.datatypes import AccessRule
|
||||
from llama_stack.providers.utils.sqlstore.api import ColumnDefinition, ColumnType
|
||||
from llama_stack.providers.utils.sqlstore.sqlstore import SqlStore, sqlstore_impl
|
||||
from llama_stack.providers.utils.sqlstore.authorized_sqlstore import AuthorizedSqlStore
|
||||
from llama_stack.providers.utils.sqlstore.sqlstore import sqlstore_impl
|
||||
|
||||
from .config import LocalfsFilesImplConfig
|
||||
|
||||
|
||||
class LocalfsFilesImpl(Files):
|
||||
def __init__(self, config: LocalfsFilesImplConfig) -> None:
|
||||
def __init__(self, config: LocalfsFilesImplConfig, policy: list[AccessRule]) -> None:
|
||||
self.config = config
|
||||
self.sql_store: SqlStore | None = None
|
||||
self.policy = policy
|
||||
self.sql_store: AuthorizedSqlStore | None = None
|
||||
|
||||
async def initialize(self) -> None:
|
||||
"""Initialize the files provider by setting up storage directory and metadata database."""
|
||||
|
|
@ -37,7 +40,7 @@ class LocalfsFilesImpl(Files):
|
|||
storage_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Initialize SQL store for metadata
|
||||
self.sql_store = sqlstore_impl(self.config.metadata_store)
|
||||
self.sql_store = AuthorizedSqlStore(sqlstore_impl(self.config.metadata_store))
|
||||
await self.sql_store.create_table(
|
||||
"openai_files",
|
||||
{
|
||||
|
|
@ -51,6 +54,9 @@ class LocalfsFilesImpl(Files):
|
|||
},
|
||||
)
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
pass
|
||||
|
||||
def _generate_file_id(self) -> str:
|
||||
"""Generate a unique file ID for OpenAI API."""
|
||||
return f"file-{uuid.uuid4().hex}"
|
||||
|
|
@ -123,6 +129,7 @@ class LocalfsFilesImpl(Files):
|
|||
|
||||
paginated_result = await self.sql_store.fetch_all(
|
||||
table="openai_files",
|
||||
policy=self.policy,
|
||||
where=where_conditions if where_conditions else None,
|
||||
order_by=[("created_at", order.value)],
|
||||
cursor=("id", after) if after else None,
|
||||
|
|
@ -153,7 +160,7 @@ class LocalfsFilesImpl(Files):
|
|||
if not self.sql_store:
|
||||
raise RuntimeError("Files provider not initialized")
|
||||
|
||||
row = await self.sql_store.fetch_one("openai_files", where={"id": file_id})
|
||||
row = await self.sql_store.fetch_one("openai_files", policy=self.policy, where={"id": file_id})
|
||||
if not row:
|
||||
raise ValueError(f"File with id {file_id} not found")
|
||||
|
||||
|
|
@ -171,7 +178,7 @@ class LocalfsFilesImpl(Files):
|
|||
if not self.sql_store:
|
||||
raise RuntimeError("Files provider not initialized")
|
||||
|
||||
row = await self.sql_store.fetch_one("openai_files", where={"id": file_id})
|
||||
row = await self.sql_store.fetch_one("openai_files", policy=self.policy, where={"id": file_id})
|
||||
if not row:
|
||||
raise ValueError(f"File with id {file_id} not found")
|
||||
|
||||
|
|
@ -194,7 +201,7 @@ class LocalfsFilesImpl(Files):
|
|||
raise RuntimeError("Files provider not initialized")
|
||||
|
||||
# Get file metadata
|
||||
row = await self.sql_store.fetch_one("openai_files", where={"id": file_id})
|
||||
row = await self.sql_store.fetch_one("openai_files", policy=self.policy, where={"id": file_id})
|
||||
if not row:
|
||||
raise ValueError(f"File with id {file_id} not found")
|
||||
|
||||
|
|
|
|||
|
|
@ -98,7 +98,7 @@ class ProcessingMessageWrapper(BaseModel):


def mp_rank_0() -> bool:
    return get_model_parallel_rank() == 0
    return bool(get_model_parallel_rank() == 0)


def encode_msg(msg: ProcessingMessage) -> bytes:

@ -125,7 +125,7 @@ def retrieve_requests(reply_socket_url: str):
            reply_socket.send_multipart([client_id, encode_msg(obj)])

    while True:
        tasks = [None]
        tasks: list[ProcessingMessage | None] = [None]
        if mp_rank_0():
            client_id, maybe_task_json = maybe_get_work(reply_socket)
            if maybe_task_json is not None:

@ -152,7 +152,7 @@ def retrieve_requests(reply_socket_url: str):
                    break

        for obj in out:
            updates = [None]
            updates: list[ProcessingMessage | None] = [None]
            if mp_rank_0():
                _, update_json = maybe_get_work(reply_socket)
                update = maybe_parse_message(update_json)
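
The annotations added above matter because a literal `[None]` is otherwise inferred as `list[None]`, which rejects the ProcessingMessage objects appended on rank 0. A tiny, self-contained sketch of the same pattern (the class below is a stand-in, not the module's real ProcessingMessage union):

    class FakeMessage:  # stand-in for ProcessingMessage, purely illustrative
        pass


    work_items: list[FakeMessage | None] = [None]  # without the annotation, a type checker infers list[None]
    work_items.append(FakeMessage())               # accepted thanks to the explicit element type
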
@ -1,17 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import Any

from .config import VLLMConfig


async def get_provider_impl(config: VLLMConfig, _deps: dict[str, Any]):
    from .vllm import VLLMInferenceImpl

    impl = VLLMInferenceImpl(config)
    await impl.initialize()
    return impl
@ -1,53 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import Any

from pydantic import BaseModel, Field

from llama_stack.schema_utils import json_schema_type


@json_schema_type
class VLLMConfig(BaseModel):
    """Configuration for the vLLM inference provider.

    Note that the model name is no longer part of this static configuration.
    You can bind an instance of this provider to a specific model with the
    ``models.register()`` API call."""

    tensor_parallel_size: int = Field(
        default=1,
        description="Number of tensor parallel replicas (number of GPUs to use).",
    )
    max_tokens: int = Field(
        default=4096,
        description="Maximum number of tokens to generate.",
    )
    max_model_len: int = Field(default=4096, description="Maximum context length to use during serving.")
    max_num_seqs: int = Field(default=4, description="Maximum parallel batch size for generation.")
    enforce_eager: bool = Field(
        default=False,
        description="Whether to use eager mode for inference (otherwise cuda graphs are used).",
    )
    gpu_memory_utilization: float = Field(
        default=0.3,
        description=(
            "How much GPU memory will be allocated when this provider has finished "
            "loading, including memory that was already allocated before loading."
        ),
    )

    @classmethod
    def sample_run_config(cls, **kwargs: Any) -> dict[str, Any]:
        return {
            "tensor_parallel_size": "${env.TENSOR_PARALLEL_SIZE:=1}",
            "max_tokens": "${env.MAX_TOKENS:=4096}",
            "max_model_len": "${env.MAX_MODEL_LEN:=4096}",
            "max_num_seqs": "${env.MAX_NUM_SEQS:=4}",
            "enforce_eager": "${env.ENFORCE_EAGER:=False}",
            "gpu_memory_utilization": "${env.GPU_MEMORY_UTILIZATION:=0.3}",
        }
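
For reference, the removed config resolved its sample run-config values through environment-variable placeholders of the form `${env.NAME:=default}`. A hedged sketch of how such a placeholder resolves when the variable is unset; the resolver below is illustrative only and not the stack's real substitution code:

    import os
    import re


    def resolve_env_placeholder(value: str) -> str:
        # Illustrative resolver for "${env.NAME:=default}" strings like those in sample_run_config().
        match = re.fullmatch(r"\$\{env\.([A-Z_]+):=(.*)\}", value)
        if not match:
            return value
        name, default = match.groups()
        return os.environ.get(name, default)


    print(resolve_env_placeholder("${env.TENSOR_PARALLEL_SIZE:=1}"))  # -> "1" unless TENSOR_PARALLEL_SIZE is set
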
@ -1,170 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
|
||||
import vllm
|
||||
|
||||
from llama_stack.apis.inference import (
|
||||
ChatCompletionRequest,
|
||||
GrammarResponseFormat,
|
||||
JsonSchemaResponseFormat,
|
||||
Message,
|
||||
ToolChoice,
|
||||
ToolDefinition,
|
||||
UserMessage,
|
||||
)
|
||||
from llama_stack.models.llama.datatypes import BuiltinTool
|
||||
from llama_stack.providers.utils.inference.openai_compat import (
|
||||
convert_message_to_openai_dict,
|
||||
get_sampling_options,
|
||||
)
|
||||
|
||||
###############################################################################
|
||||
# This file contains OpenAI compatibility code that is currently only used
|
||||
# by the inline vLLM connector. Some or all of this code may be moved to a
|
||||
# central location at a later date.
|
||||
|
||||
|
||||
def _merge_context_into_content(message: Message) -> Message: # type: ignore
|
||||
"""
|
||||
Merge the ``context`` field of a Llama Stack ``Message`` object into
|
||||
the content field for compabilitiy with OpenAI-style APIs.
|
||||
|
||||
Generates a content string that emulates the current behavior
|
||||
of ``llama_models.llama3.api.chat_format.encode_message()``.
|
||||
|
||||
:param message: Message that may include ``context`` field
|
||||
|
||||
:returns: A version of ``message`` with any context merged into the
|
||||
``content`` field.
|
||||
"""
|
||||
if not isinstance(message, UserMessage): # Separate type check for linter
|
||||
return message
|
||||
if message.context is None:
|
||||
return message
|
||||
return UserMessage(
|
||||
role=message.role,
|
||||
# Emumate llama_models.llama3.api.chat_format.encode_message()
|
||||
content=message.content + "\n\n" + message.context,
|
||||
context=None,
|
||||
)
|
||||
|
||||
|
||||
def _llama_stack_tools_to_openai_tools(
|
||||
tools: list[ToolDefinition] | None = None,
|
||||
) -> list[vllm.entrypoints.openai.protocol.ChatCompletionToolsParam]:
|
||||
"""
|
||||
Convert the list of available tools from Llama Stack's format to vLLM's
|
||||
version of OpenAI's format.
|
||||
"""
|
||||
if tools is None:
|
||||
return []
|
||||
|
||||
result = []
|
||||
for t in tools:
|
||||
if isinstance(t.tool_name, BuiltinTool):
|
||||
raise NotImplementedError("Built-in tools not yet implemented")
|
||||
if t.parameters is None:
|
||||
parameters = None
|
||||
else: # if t.parameters is not None
|
||||
# Convert the "required" flags to a list of required params
|
||||
required_params = [k for k, v in t.parameters.items() if v.required]
|
||||
parameters = {
|
||||
"type": "object", # Mystery value that shows up in OpenAI docs
|
||||
"properties": {
|
||||
k: {"type": v.param_type, "description": v.description} for k, v in t.parameters.items()
|
||||
},
|
||||
"required": required_params,
|
||||
}
|
||||
|
||||
function_def = vllm.entrypoints.openai.protocol.FunctionDefinition(
|
||||
name=t.tool_name, description=t.description, parameters=parameters
|
||||
)
|
||||
|
||||
# Every tool definition is double-boxed in a ChatCompletionToolsParam
|
||||
result.append(vllm.entrypoints.openai.protocol.ChatCompletionToolsParam(function=function_def))
|
||||
return result
|
||||
|
||||
|
||||
async def llama_stack_chat_completion_to_openai_chat_completion_dict(
|
||||
request: ChatCompletionRequest,
|
||||
) -> dict:
|
||||
"""
|
||||
Convert a chat completion request in Llama Stack format into an
|
||||
equivalent set of arguments to pass to an OpenAI-compatible
|
||||
chat completions API.
|
||||
|
||||
:param request: Bundled request parameters in Llama Stack format.
|
||||
|
||||
:returns: Dictionary of key-value pairs to use as an initializer
|
||||
for a dataclass or to be converted directly to JSON and sent
|
||||
over the wire.
|
||||
"""
|
||||
|
||||
converted_messages = [
|
||||
# This mystery async call makes the parent function also be async
|
||||
await convert_message_to_openai_dict(_merge_context_into_content(m), download=True)
|
||||
for m in request.messages
|
||||
]
|
||||
converted_tools = _llama_stack_tools_to_openai_tools(request.tools)
|
||||
|
||||
# Llama will try to use built-in tools with no tool catalog, so don't enable
|
||||
# tool choice unless at least one tool is enabled.
|
||||
converted_tool_choice = "none"
|
||||
if (
|
||||
request.tool_config is not None
|
||||
and request.tool_config.tool_choice == ToolChoice.auto
|
||||
and request.tools is not None
|
||||
and len(request.tools) > 0
|
||||
):
|
||||
converted_tool_choice = "auto"
|
||||
|
||||
# TODO: Figure out what to do with the tool_prompt_format argument.
|
||||
# Other connectors appear to drop it quietly.
|
||||
|
||||
# Use Llama Stack shared code to translate sampling parameters.
|
||||
sampling_options = get_sampling_options(request.sampling_params)
|
||||
|
||||
# get_sampling_options() translates repetition penalties to an option that
|
||||
# OpenAI's APIs don't know about.
|
||||
# vLLM's OpenAI-compatible API also handles repetition penalties wrong.
|
||||
# For now, translate repetition penalties into a format that vLLM's broken
|
||||
# API will handle correctly. Two wrongs make a right...
|
||||
if "repeat_penalty" in sampling_options:
|
||||
del sampling_options["repeat_penalty"]
|
||||
if request.sampling_params.repetition_penalty is not None and request.sampling_params.repetition_penalty != 1.0:
|
||||
sampling_options["repetition_penalty"] = request.sampling_params.repetition_penalty
|
||||
|
||||
# Convert a single response format into four different parameters, per
|
||||
# the OpenAI spec
|
||||
guided_decoding_options = dict()
|
||||
if request.response_format is None:
|
||||
# Use defaults
|
||||
pass
|
||||
elif isinstance(request.response_format, JsonSchemaResponseFormat):
|
||||
guided_decoding_options["guided_json"] = request.response_format.json_schema
|
||||
elif isinstance(request.response_format, GrammarResponseFormat):
|
||||
guided_decoding_options["guided_grammar"] = request.response_format.bnf
|
||||
else:
|
||||
raise TypeError(f"ResponseFormat object is of unexpected subtype '{type(request.response_format)}'")
|
||||
|
||||
logprob_options = dict()
|
||||
if request.logprobs is not None:
|
||||
logprob_options["logprobs"] = request.logprobs.top_k
|
||||
|
||||
# Marshall together all the arguments for a ChatCompletionRequest
|
||||
request_options = {
|
||||
"model": request.model,
|
||||
"messages": converted_messages,
|
||||
"tools": converted_tools,
|
||||
"tool_choice": converted_tool_choice,
|
||||
"stream": request.stream,
|
||||
**sampling_options,
|
||||
**guided_decoding_options,
|
||||
**logprob_options,
|
||||
}
|
||||
|
||||
return request_options
|
||||
|
|
@ -1,811 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import json
|
||||
import re
|
||||
import uuid
|
||||
from collections.abc import AsyncGenerator, AsyncIterator
|
||||
|
||||
# These vLLM modules contain names that overlap with Llama Stack names, so we import
|
||||
# fully-qualified names
|
||||
import vllm.entrypoints.openai.protocol
|
||||
import vllm.sampling_params
|
||||
from vllm.engine.arg_utils import AsyncEngineArgs
|
||||
from vllm.engine.async_llm_engine import AsyncLLMEngine
|
||||
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
|
||||
from vllm.entrypoints.openai.serving_models import BaseModelPath, OpenAIServingModels
|
||||
|
||||
from llama_stack.apis.common.content_types import (
|
||||
InterleavedContent,
|
||||
InterleavedContentItem,
|
||||
TextDelta,
|
||||
ToolCallDelta,
|
||||
)
|
||||
from llama_stack.apis.inference import (
|
||||
ChatCompletionRequest,
|
||||
ChatCompletionResponse,
|
||||
ChatCompletionResponseEvent,
|
||||
ChatCompletionResponseEventType,
|
||||
ChatCompletionResponseStreamChunk,
|
||||
CompletionMessage,
|
||||
CompletionResponse,
|
||||
CompletionResponseStreamChunk,
|
||||
EmbeddingsResponse,
|
||||
EmbeddingTaskType,
|
||||
GrammarResponseFormat,
|
||||
Inference,
|
||||
JsonSchemaResponseFormat,
|
||||
LogProbConfig,
|
||||
Message,
|
||||
OpenAIEmbeddingsResponse,
|
||||
ResponseFormat,
|
||||
SamplingParams,
|
||||
TextTruncation,
|
||||
TokenLogProbs,
|
||||
ToolChoice,
|
||||
ToolConfig,
|
||||
TopKSamplingStrategy,
|
||||
TopPSamplingStrategy,
|
||||
)
|
||||
from llama_stack.apis.models import Model
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.models.llama import sku_list
|
||||
from llama_stack.models.llama.datatypes import (
|
||||
StopReason,
|
||||
ToolCall,
|
||||
ToolDefinition,
|
||||
ToolPromptFormat,
|
||||
)
|
||||
from llama_stack.models.llama.llama3.chat_format import ChatFormat
|
||||
from llama_stack.models.llama.llama3.tokenizer import Tokenizer
|
||||
from llama_stack.providers.remote.inference.vllm.vllm import build_hf_repo_model_entries
|
||||
from llama_stack.providers.utils.inference.model_registry import (
|
||||
ModelRegistryHelper,
|
||||
ModelsProtocolPrivate,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.openai_compat import (
|
||||
OpenAIChatCompletionToLlamaStackMixin,
|
||||
OpenAICompatCompletionChoice,
|
||||
OpenAICompatCompletionResponse,
|
||||
OpenAICompletionToLlamaStackMixin,
|
||||
get_stop_reason,
|
||||
process_chat_completion_stream_response,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||
chat_completion_request_to_prompt,
|
||||
)
|
||||
|
||||
from .config import VLLMConfig
|
||||
from .openai_utils import llama_stack_chat_completion_to_openai_chat_completion_dict
|
||||
|
||||
# Map from Hugging Face model architecture name to appropriate tool parser.
|
||||
# See vllm.entrypoints.openai.tool_parsers.ToolParserManager.tool_parsers for the full list of
|
||||
# available parsers.
|
||||
# TODO: Expand this list
|
||||
CONFIG_TYPE_TO_TOOL_PARSER = {
|
||||
"GraniteConfig": "granite",
|
||||
"MllamaConfig": "llama3_json",
|
||||
"LlamaConfig": "llama3_json",
|
||||
}
|
||||
DEFAULT_TOOL_PARSER = "pythonic"
|
||||
|
||||
|
||||
logger = get_logger(__name__, category="inference")
|
||||
|
||||
|
||||
def _random_uuid_str() -> str:
|
||||
return str(uuid.uuid4().hex)
|
||||
|
||||
|
||||
def _response_format_to_guided_decoding_params(
|
||||
response_format: ResponseFormat | None, # type: ignore
|
||||
) -> vllm.sampling_params.GuidedDecodingParams:
|
||||
"""
|
||||
Translate constrained decoding parameters from Llama Stack's format to vLLM's format.
|
||||
|
||||
:param response_format: Llama Stack version of constrained decoding info. Can be ``None``,
|
||||
indicating no constraints.
|
||||
:returns: The equivalent dataclass object for the low-level inference layer of vLLM.
|
||||
"""
|
||||
if response_format is None:
|
||||
# As of vLLM 0.6.3, the default constructor for GuidedDecodingParams() returns an invalid
|
||||
# value that crashes the executor on some code paths. Use ``None`` instead.
|
||||
return None
|
||||
|
||||
# Llama Stack currently implements fewer types of constrained decoding than vLLM does.
|
||||
# Translate the types that exist and detect if Llama Stack adds new ones.
|
||||
if isinstance(response_format, JsonSchemaResponseFormat):
|
||||
return vllm.sampling_params.GuidedDecodingParams(json=response_format.json_schema)
|
||||
elif isinstance(response_format, GrammarResponseFormat):
|
||||
# BNF grammar.
|
||||
# Llama Stack uses the parse tree of the grammar, while vLLM uses the string
|
||||
# representation of the grammar.
|
||||
raise TypeError(
|
||||
"Constrained decoding with BNF grammars is not currently implemented, because the "
|
||||
"reference implementation does not implement it."
|
||||
)
|
||||
else:
|
||||
raise TypeError(f"ResponseFormat object is of unexpected subtype '{type(response_format)}'")
|
||||
|
||||
|
||||
def _convert_sampling_params(
|
||||
sampling_params: SamplingParams | None,
|
||||
response_format: ResponseFormat | None, # type: ignore
|
||||
log_prob_config: LogProbConfig | None,
|
||||
) -> vllm.SamplingParams:
|
||||
"""Convert sampling and constrained decoding configuration from Llama Stack's format to vLLM's
|
||||
format."""
|
||||
# In the absence of provided config values, use Llama Stack defaults as encoded in the Llama
|
||||
# Stack dataclasses. These defaults are different from vLLM's defaults.
|
||||
if sampling_params is None:
|
||||
sampling_params = SamplingParams()
|
||||
if log_prob_config is None:
|
||||
log_prob_config = LogProbConfig()
|
||||
|
||||
if isinstance(sampling_params.strategy, TopKSamplingStrategy):
|
||||
if sampling_params.strategy.top_k == 0:
|
||||
# vLLM treats "k" differently for top-k sampling
|
||||
vllm_top_k = -1
|
||||
else:
|
||||
vllm_top_k = sampling_params.strategy.top_k
|
||||
else:
|
||||
vllm_top_k = -1
|
||||
|
||||
if isinstance(sampling_params.strategy, TopPSamplingStrategy):
|
||||
vllm_top_p = sampling_params.strategy.top_p
|
||||
# Llama Stack only allows temperature with top-P.
|
||||
vllm_temperature = sampling_params.strategy.temperature
|
||||
else:
|
||||
vllm_top_p = 1.0
|
||||
vllm_temperature = 0.0
|
||||
|
||||
# vLLM allows top-p and top-k at the same time.
|
||||
vllm_sampling_params = vllm.SamplingParams.from_optional(
|
||||
max_tokens=(None if sampling_params.max_tokens == 0 else sampling_params.max_tokens),
|
||||
temperature=vllm_temperature,
|
||||
top_p=vllm_top_p,
|
||||
top_k=vllm_top_k,
|
||||
repetition_penalty=sampling_params.repetition_penalty,
|
||||
guided_decoding=_response_format_to_guided_decoding_params(response_format),
|
||||
logprobs=log_prob_config.top_k,
|
||||
)
|
||||
return vllm_sampling_params
|
||||
|
||||
|
||||
class VLLMInferenceImpl(
|
||||
Inference,
|
||||
OpenAIChatCompletionToLlamaStackMixin,
|
||||
OpenAICompletionToLlamaStackMixin,
|
||||
ModelsProtocolPrivate,
|
||||
):
|
||||
"""
|
||||
vLLM-based inference model adapter for Llama Stack with support for multiple models.
|
||||
|
||||
Requires the configuration parameters documented in the :class:`VllmConfig2` class.
|
||||
"""
|
||||
|
||||
config: VLLMConfig
|
||||
register_helper: ModelRegistryHelper
|
||||
model_ids: set[str]
|
||||
resolved_model_id: str | None
|
||||
engine: AsyncLLMEngine | None
|
||||
chat: OpenAIServingChat | None
|
||||
is_meta_llama_model: bool
|
||||
|
||||
def __init__(self, config: VLLMConfig):
|
||||
self.config = config
|
||||
logger.info(f"Config is: {self.config}")
|
||||
|
||||
self.register_helper = ModelRegistryHelper(build_hf_repo_model_entries())
|
||||
self.formatter = ChatFormat(Tokenizer.get_instance())
|
||||
|
||||
# The following are initialized when paths are bound to this provider
|
||||
self.resolved_model_id = None
|
||||
self.model_ids = set()
|
||||
self.engine = None
|
||||
self.chat = None
|
||||
self.is_meta_llama_model = False
|
||||
|
||||
###########################################################################
|
||||
# METHODS INHERITED FROM IMPLICIT BASE CLASS.
|
||||
# TODO: Make this class inherit from the new base class ProviderBase once that class exists.
|
||||
|
||||
async def initialize(self) -> None:
|
||||
"""
|
||||
Callback that is invoked through many levels of indirection during provider class
|
||||
instantiation, sometime after when __init__() is called and before any model registration
|
||||
methods or methods connected to a REST API are called.
|
||||
|
||||
It's not clear what assumptions the class can make about the platform's initialization
|
||||
state here that can't be made during __init__(), and vLLM can't be started until we know
|
||||
what model it's supposed to be serving, so nothing happens here currently.
|
||||
"""
|
||||
pass
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
logger.info(f"Shutting down inline vLLM inference provider {self}.")
|
||||
if self.engine is not None:
|
||||
self.engine.shutdown_background_loop()
|
||||
self.engine = None
|
||||
self.chat = None
|
||||
self.model_ids = set()
|
||||
self.resolved_model_id = None
|
||||
|
||||
###########################################################################
|
||||
# METHODS INHERITED FROM ModelsProtocolPrivate INTERFACE
|
||||
|
||||
# Note that the return type of the superclass method is WRONG
|
||||
async def register_model(self, model: Model) -> Model:
|
||||
"""
|
||||
Callback that is called when the server associates an inference endpoint with an
|
||||
inference provider.
|
||||
|
||||
:param model: Object that encapsulates parameters necessary for identifying a specific
|
||||
LLM.
|
||||
|
||||
:returns: The input ``Model`` object. It may or may not be permissible to change fields
|
||||
before returning this object.
|
||||
"""
|
||||
logger.debug(f"In register_model({model})")
|
||||
|
||||
# First attempt to interpret the model coordinates as a Llama model name
|
||||
resolved_llama_model = sku_list.resolve_model(model.provider_model_id)
|
||||
if resolved_llama_model is not None:
|
||||
# Load from Hugging Face repo into default local cache dir
|
||||
model_id_for_vllm = resolved_llama_model.huggingface_repo
|
||||
|
||||
# Detect a genuine Meta Llama model to trigger Meta-specific preprocessing.
|
||||
# Don't set self.is_meta_llama_model until we actually load the model.
|
||||
is_meta_llama_model = True
|
||||
else: # if resolved_llama_model is None
|
||||
# Not a Llama model name. Pass the model id through to vLLM's loader
|
||||
model_id_for_vllm = model.provider_model_id
|
||||
is_meta_llama_model = False
|
||||
|
||||
if self.resolved_model_id is not None:
|
||||
if model_id_for_vllm != self.resolved_model_id:
|
||||
raise ValueError(
|
||||
f"Attempted to serve two LLMs (ids '{self.resolved_model_id}') and "
|
||||
f"'{model_id_for_vllm}') from one copy of provider '{self}'. Use multiple "
|
||||
f"copies of the provider instead."
|
||||
)
|
||||
else:
|
||||
# Model already loaded
|
||||
logger.info(
|
||||
f"Requested id {model} resolves to {model_id_for_vllm}, which is already loaded. Continuing."
|
||||
)
|
||||
self.model_ids.add(model.model_id)
|
||||
return model
|
||||
|
||||
logger.info(f"Requested id {model} resolves to {model_id_for_vllm}. Loading {model_id_for_vllm}.")
|
||||
if is_meta_llama_model:
|
||||
logger.info(f"Model {model_id_for_vllm} is a Meta Llama model.")
|
||||
self.is_meta_llama_model = is_meta_llama_model
|
||||
|
||||
# If we get here, this is the first time registering a model.
|
||||
# Preload so that the first inference request won't time out.
|
||||
engine_args = AsyncEngineArgs(
|
||||
model=model_id_for_vllm,
|
||||
tokenizer=model_id_for_vllm,
|
||||
tensor_parallel_size=self.config.tensor_parallel_size,
|
||||
enforce_eager=self.config.enforce_eager,
|
||||
gpu_memory_utilization=self.config.gpu_memory_utilization,
|
||||
max_num_seqs=self.config.max_num_seqs,
|
||||
max_model_len=self.config.max_model_len,
|
||||
)
|
||||
self.engine = AsyncLLMEngine.from_engine_args(engine_args)
|
||||
|
||||
# vLLM currently requires the user to specify the tool parser manually. To choose a tool
|
||||
# parser, we need to determine what model architecture is being used. For now, we infer
|
||||
# that information from what config class the model uses.
|
||||
low_level_model_config = self.engine.engine.get_model_config()
|
||||
hf_config = low_level_model_config.hf_config
|
||||
hf_config_class_name = hf_config.__class__.__name__
|
||||
if hf_config_class_name in CONFIG_TYPE_TO_TOOL_PARSER:
|
||||
tool_parser = CONFIG_TYPE_TO_TOOL_PARSER[hf_config_class_name]
|
||||
else:
|
||||
# No info -- choose a default so we can at least attempt tool
|
||||
# use.
|
||||
tool_parser = DEFAULT_TOOL_PARSER
|
||||
logger.debug(f"{hf_config_class_name=}")
|
||||
logger.debug(f"{tool_parser=}")
|
||||
|
||||
# Wrap the lower-level engine in an OpenAI-compatible chat API
|
||||
model_config = await self.engine.get_model_config()
|
||||
self.chat = OpenAIServingChat(
|
||||
engine_client=self.engine,
|
||||
model_config=model_config,
|
||||
models=OpenAIServingModels(
|
||||
engine_client=self.engine,
|
||||
model_config=model_config,
|
||||
base_model_paths=[
|
||||
# The layer below us will only see resolved model IDs
|
||||
BaseModelPath(model_id_for_vllm, model_id_for_vllm)
|
||||
],
|
||||
),
|
||||
response_role="assistant",
|
||||
request_logger=None, # Use default logging
|
||||
chat_template=None, # Use default template from model checkpoint
|
||||
enable_auto_tools=True,
|
||||
tool_parser=tool_parser,
|
||||
chat_template_content_format="auto",
|
||||
)
|
||||
self.resolved_model_id = model_id_for_vllm
|
||||
self.model_ids.add(model.model_id)
|
||||
|
||||
logger.info(f"Finished preloading model: {model_id_for_vllm}")
|
||||
|
||||
return model
|
||||
|
||||
async def unregister_model(self, model_id: str) -> None:
|
||||
"""
|
||||
Callback that is called when the server removes an inference endpoint from an inference
|
||||
provider.
|
||||
|
||||
:param model_id: The same external ID that the higher layers of the stack previously passed
|
||||
to :func:`register_model()`
|
||||
"""
|
||||
if model_id not in self.model_ids:
|
||||
raise ValueError(
|
||||
f"Attempted to unregister model ID '{model_id}', but that ID is not registered to this provider."
|
||||
)
|
||||
self.model_ids.remove(model_id)
|
||||
|
||||
if len(self.model_ids) == 0:
|
||||
# Last model was just unregistered. Shut down the connection to vLLM and free up
|
||||
# resources.
|
||||
# Note that this operation may cause in-flight chat completion requests on the
|
||||
# now-unregistered model to return errors.
|
||||
self.resolved_model_id = None
|
||||
self.chat = None
|
||||
self.engine.shutdown_background_loop()
|
||||
self.engine = None
|
||||
|
||||
###########################################################################
|
||||
# METHODS INHERITED FROM Inference INTERFACE
|
||||
|
||||
async def completion(
|
||||
self,
|
||||
model_id: str,
|
||||
content: InterleavedContent,
|
||||
sampling_params: SamplingParams | None = None,
|
||||
response_format: ResponseFormat | None = None,
|
||||
stream: bool | None = False,
|
||||
logprobs: LogProbConfig | None = None,
|
||||
) -> CompletionResponse | AsyncIterator[CompletionResponseStreamChunk]:
|
||||
if model_id not in self.model_ids:
|
||||
raise ValueError(
|
||||
f"This adapter is not registered to model id '{model_id}'. Registered IDs are: {self.model_ids}"
|
||||
)
|
||||
if not isinstance(content, str):
|
||||
raise NotImplementedError("Multimodal input not currently supported")
|
||||
if sampling_params is None:
|
||||
sampling_params = SamplingParams()
|
||||
|
||||
converted_sampling_params = _convert_sampling_params(sampling_params, response_format, logprobs)
|
||||
|
||||
logger.debug(f"{converted_sampling_params=}")
|
||||
|
||||
if stream:
|
||||
return self._streaming_completion(content, converted_sampling_params)
|
||||
else:
|
||||
streaming_result = None
|
||||
async for _ in self._streaming_completion(content, converted_sampling_params):
|
||||
pass
|
||||
return CompletionResponse(
|
||||
content=streaming_result.delta,
|
||||
stop_reason=streaming_result.stop_reason,
|
||||
logprobs=streaming_result.logprobs,
|
||||
)
|
||||
|
||||
async def embeddings(
|
||||
self,
|
||||
model_id: str,
|
||||
contents: list[str] | list[InterleavedContentItem],
|
||||
text_truncation: TextTruncation | None = TextTruncation.none,
|
||||
output_dimension: int | None = None,
|
||||
task_type: EmbeddingTaskType | None = None,
|
||||
) -> EmbeddingsResponse:
|
||||
raise NotImplementedError()
|
||||
|
||||
async def openai_embeddings(
|
||||
self,
|
||||
model: str,
|
||||
input: str | list[str],
|
||||
encoding_format: str | None = "float",
|
||||
dimensions: int | None = None,
|
||||
user: str | None = None,
|
||||
) -> OpenAIEmbeddingsResponse:
|
||||
raise NotImplementedError()
|
||||
|
||||
async def chat_completion(
|
||||
self,
|
||||
model_id: str,
|
||||
messages: list[Message], # type: ignore
|
||||
sampling_params: SamplingParams | None = None,
|
||||
response_format: ResponseFormat | None = None, # type: ignore
|
||||
tools: list[ToolDefinition] | None = None,
|
||||
tool_choice: ToolChoice | None = ToolChoice.auto,
|
||||
tool_prompt_format: ToolPromptFormat | None = None,
|
||||
stream: bool | None = False,
|
||||
logprobs: LogProbConfig | None = None,
|
||||
tool_config: ToolConfig | None = None,
|
||||
) -> ChatCompletionResponse | ChatCompletionResponseStreamChunk:
|
||||
sampling_params = sampling_params or SamplingParams()
|
||||
if model_id not in self.model_ids:
|
||||
raise ValueError(
|
||||
f"This adapter is not registered to model id '{model_id}'. Registered IDs are: {self.model_ids}"
|
||||
)
|
||||
|
||||
# Convert to Llama Stack internal format for consistency
|
||||
request = ChatCompletionRequest(
|
||||
model=self.resolved_model_id,
|
||||
messages=messages,
|
||||
sampling_params=sampling_params,
|
||||
response_format=response_format,
|
||||
tools=tools,
|
||||
tool_choice=tool_choice,
|
||||
tool_prompt_format=tool_prompt_format,
|
||||
stream=stream,
|
||||
logprobs=logprobs,
|
||||
)
|
||||
|
||||
if self.is_meta_llama_model:
|
||||
# Bypass vLLM chat templating layer for Meta Llama models, because the
|
||||
# templating layer in Llama Stack currently produces better results.
|
||||
logger.debug(
|
||||
f"Routing {self.resolved_model_id} chat completion through "
|
||||
f"Llama Stack's templating layer instead of vLLM's."
|
||||
)
|
||||
return await self._chat_completion_for_meta_llama(request)
|
||||
|
||||
logger.debug(f"{self.resolved_model_id} is not a Meta Llama model")
|
||||
|
||||
# Arguments to the vLLM call must be packaged as a ChatCompletionRequest dataclass.
|
||||
# Note that this dataclass has the same name as a similar dataclass in Llama Stack.
|
||||
request_options = await llama_stack_chat_completion_to_openai_chat_completion_dict(request)
|
||||
chat_completion_request = vllm.entrypoints.openai.protocol.ChatCompletionRequest(**request_options)
|
||||
|
||||
logger.debug(f"Converted request: {chat_completion_request}")
|
||||
|
||||
vllm_result = await self.chat.create_chat_completion(chat_completion_request)
|
||||
logger.debug(f"Result from vLLM: {vllm_result}")
|
||||
if isinstance(vllm_result, vllm.entrypoints.openai.protocol.ErrorResponse):
|
||||
raise ValueError(f"Error from vLLM layer: {vllm_result}")
|
||||
|
||||
# Return type depends on "stream" argument
|
||||
if stream:
|
||||
if not isinstance(vllm_result, AsyncGenerator):
|
||||
raise TypeError(f"Unexpected result type {type(vllm_result)} for streaming inference call")
|
||||
# vLLM client returns a stream of strings, which need to be parsed.
|
||||
# Stream comes in the form of an async generator.
|
||||
return self._convert_streaming_results(vllm_result)
|
||||
else:
|
||||
if not isinstance(vllm_result, vllm.entrypoints.openai.protocol.ChatCompletionResponse):
|
||||
raise TypeError(f"Unexpected result type {type(vllm_result)} for non-streaming inference call")
|
||||
return self._convert_non_streaming_results(vllm_result)
|
||||
|
||||
###########################################################################
|
||||
# INTERNAL METHODS
|
||||
|
||||
async def _streaming_completion(
|
||||
self, content: str, sampling_params: vllm.SamplingParams
|
||||
) -> AsyncIterator[CompletionResponseStreamChunk]:
|
||||
"""Internal implementation of :func:`completion()` API for the streaming case. Assumes
|
||||
that arguments have been validated upstream.
|
||||
|
||||
:param content: Must be a string
|
||||
:param sampling_params: Paramters from public API's ``response_format``
|
||||
and ``sampling_params`` arguments, converted to VLLM format
|
||||
"""
|
||||
# We run agains the vLLM generate() call directly instead of using the OpenAI-compatible
|
||||
# layer, because doing so simplifies the code here.
|
||||
|
||||
# The vLLM engine requires a unique identifier for each call to generate()
|
||||
request_id = _random_uuid_str()
|
||||
|
||||
# The vLLM generate() API is streaming-only and returns an async generator.
|
||||
# The generator returns objects of type vllm.RequestOutput.
|
||||
results_generator = self.engine.generate(content, sampling_params, request_id)
|
||||
|
||||
# Need to know the model's EOS token ID for the conversion code below.
|
||||
# AsyncLLMEngine is a wrapper around LLMEngine, and the tokenizer is only available if
|
||||
# we drill down to the LLMEngine inside the AsyncLLMEngine.
|
||||
# Similarly, the tokenizer in an LLMEngine is a wrapper around a BaseTokenizerGroup,
|
||||
# and we need to drill down to the Hugging Face tokenizer inside the BaseTokenizerGroup.
|
||||
llm_engine = self.engine.engine
|
||||
tokenizer_group = llm_engine.tokenizer
|
||||
eos_token_id = tokenizer_group.tokenizer.eos_token_id
|
||||
|
||||
request_output: vllm.RequestOutput = None
|
||||
async for request_output in results_generator:
|
||||
# Check for weird inference failures
|
||||
if request_output.outputs is None or len(request_output.outputs) == 0:
|
||||
# This case also should never happen
|
||||
raise ValueError("Inference produced empty result")
|
||||
|
||||
# If we get here, then request_output contains the final output of the generate() call.
|
||||
# The result may include multiple alternate outputs, but Llama Stack APIs only allow
|
||||
# us to return one.
|
||||
output: vllm.CompletionOutput = request_output.outputs[0]
|
||||
completion_string = output.text
|
||||
|
||||
# Convert logprobs from vLLM's format to Llama Stack's format
|
||||
logprobs = [
|
||||
TokenLogProbs(logprobs_by_token={v.decoded_token: v.logprob for _, v in logprob_dict.items()})
|
||||
for logprob_dict in output.logprobs
|
||||
]
|
||||
|
||||
# The final output chunk should be labeled with the reason that the overall generate()
|
||||
# call completed.
|
||||
logger.debug(f"{output.stop_reason=}; {type(output.stop_reason)=}")
|
||||
if output.stop_reason is None:
|
||||
stop_reason = None # Still going
|
||||
elif output.stop_reason == "stop":
|
||||
stop_reason = StopReason.end_of_turn
|
||||
elif output.stop_reason == "length":
|
||||
stop_reason = StopReason.out_of_tokens
|
||||
elif isinstance(output.stop_reason, int):
|
||||
# If the model config specifies multiple end-of-sequence tokens, then vLLM
|
||||
# will return the token ID of the EOS token in the stop_reason field.
|
||||
stop_reason = StopReason.end_of_turn
|
||||
else:
|
||||
raise ValueError(f"Unrecognized stop reason '{output.stop_reason}'")
|
||||
|
||||
# vLLM's protocol outputs the stop token, then sets end of message on the next step for
|
||||
# some reason.
|
||||
if request_output.outputs[-1].token_ids[-1] == eos_token_id:
|
||||
stop_reason = StopReason.end_of_message
|
||||
|
||||
yield CompletionResponseStreamChunk(delta=completion_string, stop_reason=stop_reason, logprobs=logprobs)
|
||||
|
||||
# Llama Stack requires that the last chunk have a stop reason, but vLLM doesn't always
|
||||
# provide one if it runs out of tokens.
|
||||
if stop_reason is None:
|
||||
yield CompletionResponseStreamChunk(
|
||||
delta=completion_string,
|
||||
stop_reason=StopReason.out_of_tokens,
|
||||
logprobs=logprobs,
|
||||
)
|
||||
|
||||
def _convert_non_streaming_results(
|
||||
self, vllm_result: vllm.entrypoints.openai.protocol.ChatCompletionResponse
|
||||
) -> ChatCompletionResponse:
|
||||
"""
|
||||
Subroutine to convert the non-streaming output of vLLM's OpenAI-compatible API into an
|
||||
equivalent Llama Stack object.
|
||||
|
||||
The result from vLLM's non-streaming API is a dataclass with the same name as the Llama
|
||||
Stack ChatCompletionResponse dataclass, but with more and different field names. We ignore
|
||||
the fields that aren't currently present in the Llama Stack dataclass.
|
||||
"""
|
||||
|
||||
# There may be multiple responses, but we can only pass through the first one.
|
||||
if len(vllm_result.choices) == 0:
|
||||
raise ValueError("Don't know how to convert response object without any responses")
|
||||
vllm_message = vllm_result.choices[0].message
|
||||
vllm_finish_reason = vllm_result.choices[0].finish_reason
|
||||
|
||||
converted_message = CompletionMessage(
|
||||
role=vllm_message.role,
|
||||
# Llama Stack API won't accept None for content field.
|
||||
content=("" if vllm_message.content is None else vllm_message.content),
|
||||
stop_reason=get_stop_reason(vllm_finish_reason),
|
||||
tool_calls=[
|
||||
ToolCall(
|
||||
call_id=t.id,
|
||||
tool_name=t.function.name,
|
||||
# vLLM function args come back as a string. Llama Stack expects JSON.
|
||||
arguments=json.loads(t.function.arguments),
|
||||
arguments_json=t.function.arguments,
|
||||
)
|
||||
for t in vllm_message.tool_calls
|
||||
],
|
||||
)
|
||||
|
||||
# TODO: Convert logprobs
|
||||
|
||||
logger.debug(f"Converted message: {converted_message}")
|
||||
|
||||
return ChatCompletionResponse(
|
||||
completion_message=converted_message,
|
||||
)
|
||||
|
||||
async def _chat_completion_for_meta_llama(
|
||||
self, request: ChatCompletionRequest
|
||||
) -> ChatCompletionResponse | AsyncIterator[ChatCompletionResponseStreamChunk]:
|
||||
"""
|
||||
Subroutine that routes chat completions for Meta Llama models through Llama Stack's
|
||||
chat template instead of using vLLM's version of that template. The Llama Stack version
|
||||
of the chat template currently produces more reliable outputs.
|
||||
|
||||
Once vLLM's support for Meta Llama models has matured more, we should consider routing
|
||||
Meta Llama requests through the vLLM chat completions API instead of using this method.
|
||||
"""
|
||||
formatter = ChatFormat(Tokenizer.get_instance())
|
||||
|
||||
# Note that this function call modifies `request` in place.
|
||||
prompt = await chat_completion_request_to_prompt(request, self.resolved_model_id)
|
||||
|
||||
model_id = list(self.model_ids)[0] # Any model ID will do here
|
||||
completion_response_or_iterator = await self.completion(
|
||||
model_id=model_id,
|
||||
content=prompt,
|
||||
sampling_params=request.sampling_params,
|
||||
response_format=request.response_format,
|
||||
stream=request.stream,
|
||||
logprobs=request.logprobs,
|
||||
)
|
||||
|
||||
if request.stream:
|
||||
if not isinstance(completion_response_or_iterator, AsyncIterator):
|
||||
raise TypeError(
|
||||
f"Received unexpected result type {type(completion_response_or_iterator)}for streaming request."
|
||||
)
|
||||
return self._chat_completion_for_meta_llama_streaming(completion_response_or_iterator, request)
|
||||
|
||||
# elsif not request.stream:
|
||||
if not isinstance(completion_response_or_iterator, CompletionResponse):
|
||||
raise TypeError(
|
||||
f"Received unexpected result type {type(completion_response_or_iterator)}for non-streaming request."
|
||||
)
|
||||
completion_response: CompletionResponse = completion_response_or_iterator
|
||||
raw_message = formatter.decode_assistant_message_from_content(
|
||||
completion_response.content, completion_response.stop_reason
|
||||
)
|
||||
return ChatCompletionResponse(
|
||||
completion_message=CompletionMessage(
|
||||
content=raw_message.content,
|
||||
stop_reason=raw_message.stop_reason,
|
||||
tool_calls=raw_message.tool_calls,
|
||||
),
|
||||
logprobs=completion_response.logprobs,
|
||||
)
|
||||
|
||||
async def _chat_completion_for_meta_llama_streaming(
|
||||
self, results_iterator: AsyncIterator, request: ChatCompletionRequest
|
||||
) -> AsyncIterator:
|
||||
"""
|
||||
Code from :func:`_chat_completion_for_meta_llama()` that needs to be a separate
|
||||
method to keep asyncio happy.
|
||||
"""
|
||||
|
||||
# Convert to OpenAI format, then use shared code to convert to Llama Stack format.
|
||||
async def _generate_and_convert_to_openai_compat():
|
||||
chunk: CompletionResponseStreamChunk # Make Pylance happy
|
||||
last_text_len = 0
|
||||
async for chunk in results_iterator:
|
||||
if chunk.stop_reason == StopReason.end_of_turn:
|
||||
finish_reason = "stop"
|
||||
elif chunk.stop_reason == StopReason.end_of_message:
|
||||
finish_reason = "eos"
|
||||
elif chunk.stop_reason == StopReason.out_of_tokens:
|
||||
finish_reason = "length"
|
||||
else:
|
||||
finish_reason = None
|
||||
|
||||
# Convert delta back to an actual delta
|
||||
text_delta = chunk.delta[last_text_len:]
|
||||
last_text_len = len(chunk.delta)
|
||||
|
||||
logger.debug(f"{text_delta=}; {finish_reason=}")
|
||||
|
||||
yield OpenAICompatCompletionResponse(
|
||||
choices=[OpenAICompatCompletionChoice(finish_reason=finish_reason, text=text_delta)]
|
||||
)
|
||||
|
||||
stream = _generate_and_convert_to_openai_compat()
|
||||
async for chunk in process_chat_completion_stream_response(stream, request):
|
||||
logger.debug(f"Returning chunk: {chunk}")
|
||||
yield chunk
|
||||
|
||||
async def _convert_streaming_results(self, vllm_result: AsyncIterator) -> AsyncIterator:
|
||||
"""
|
||||
Subroutine that wraps the streaming outputs of vLLM's OpenAI-compatible
|
||||
API into a second async iterator that returns Llama Stack objects.
|
||||
|
||||
:param vllm_result: Stream of strings that need to be parsed
|
||||
"""
|
||||
# Tool calls come in pieces, but Llama Stack expects them in bigger chunks. We build up
|
||||
# those chunks and output them at the end.
|
||||
# This data structure holds the current set of partial tool calls.
|
||||
index_to_tool_call: dict[int, dict] = dict()
|
||||
|
||||
# The Llama Stack event stream must always start with a start event. Use an empty one to
|
||||
# simplify logic below
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.start,
|
||||
delta=TextDelta(text=""),
|
||||
stop_reason=None,
|
||||
)
|
||||
)
|
||||
|
||||
converted_stop_reason = None
|
||||
async for chunk_str in vllm_result:
|
||||
# Due to OpenAI compatibility, each event in the stream will start with "data: " and
|
||||
# end with "\n\n".
|
||||
_prefix = "data: "
|
||||
_suffix = "\n\n"
|
||||
if not chunk_str.startswith(_prefix) or not chunk_str.endswith(_suffix):
|
||||
raise ValueError(f"Can't parse result string from vLLM: '{re.escape(chunk_str)}'")
|
||||
|
||||
# In between the "data: " and newlines is an event record
|
||||
data_str = chunk_str[len(_prefix) : -len(_suffix)]
|
||||
|
||||
# The end of the stream is indicated with "[DONE]"
|
||||
if data_str == "[DONE]":
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.complete,
|
||||
delta=TextDelta(text=""),
|
||||
stop_reason=converted_stop_reason,
|
||||
)
|
||||
)
|
||||
return
|
||||
|
||||
# Anything that is not "[DONE]" should be a JSON record
|
||||
parsed_chunk = json.loads(data_str)
|
||||
|
||||
logger.debug(f"Parsed JSON event to:\n{json.dumps(parsed_chunk, indent=2)}")
|
||||
|
||||
# The result may contain multiple completions, but Llama Stack APIs only support
|
||||
# returning one.
|
||||
first_choice = parsed_chunk["choices"][0]
|
||||
converted_stop_reason = get_stop_reason(first_choice["finish_reason"])
|
||||
delta_record = first_choice["delta"]
|
||||
|
||||
if "content" in delta_record:
|
||||
# Text delta
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.progress,
|
||||
delta=TextDelta(text=delta_record["content"]),
|
||||
stop_reason=converted_stop_reason,
|
||||
)
|
||||
)
|
||||
elif "tool_calls" in delta_record:
|
||||
# Tool call(s). Llama Stack APIs do not have a clear way to return partial tool
|
||||
# calls, so buffer until we get a "tool calls" stop reason
|
||||
for tc in delta_record["tool_calls"]:
|
||||
index = tc["index"]
|
||||
if index not in index_to_tool_call:
|
||||
# First time this tool call is showing up
|
||||
index_to_tool_call[index] = dict()
|
||||
tool_call = index_to_tool_call[index]
|
||||
if "id" in tc:
|
||||
tool_call["call_id"] = tc["id"]
|
||||
if "function" in tc:
|
||||
if "name" in tc["function"]:
|
||||
tool_call["tool_name"] = tc["function"]["name"]
|
||||
if "arguments" in tc["function"]:
|
||||
# Arguments comes in as pieces of a string
|
||||
if "arguments_str" not in tool_call:
|
||||
tool_call["arguments_str"] = ""
|
||||
tool_call["arguments_str"] += tc["function"]["arguments"]
|
||||
else:
|
||||
raise ValueError(f"Don't know how to parse event delta: {delta_record}")
|
||||
|
||||
if first_choice["finish_reason"] == "tool_calls":
|
||||
# Special OpenAI code for "tool calls complete".
|
||||
# Output the buffered tool calls. Llama Stack requires a separate event per tool
|
||||
# call.
|
||||
for tool_call_record in index_to_tool_call.values():
|
||||
# Arguments come in as a string. Parse the completed string.
|
||||
tool_call_record["arguments"] = json.loads(tool_call_record["arguments_str"])
|
||||
del tool_call_record["arguments_str"]
|
||||
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.progress,
|
||||
delta=ToolCallDelta(tool_call=tool_call_record, parse_status="succeeded"),
|
||||
stop_reason=converted_stop_reason,
|
||||
)
|
||||
)
|
||||
|
||||
# If we get here, we've lost the connection with the vLLM event stream before it ended
|
||||
# normally.
|
||||
raise ValueError("vLLM event stream ended without [DONE] message.")
@ -7,6 +7,7 @@
import asyncio
import json
import logging
import re
import sqlite3
import struct
from typing import Any

@ -117,6 +118,10 @@ def _rrf_rerank(
    return rrf_scores


def _make_sql_identifier(name: str) -> str:
    return re.sub(r"[^a-zA-Z0-9_]", "_", name)


class SQLiteVecIndex(EmbeddingIndex):
    """
    An index implementation that stores embeddings in a SQLite virtual table using sqlite-vec.

@ -130,9 +135,9 @@ class SQLiteVecIndex(EmbeddingIndex):
        self.dimension = dimension
        self.db_path = db_path
        self.bank_id = bank_id
        self.metadata_table = f"chunks_{bank_id}".replace("-", "_")
        self.vector_table = f"vec_chunks_{bank_id}".replace("-", "_")
        self.fts_table = f"fts_chunks_{bank_id}".replace("-", "_")
        self.metadata_table = _make_sql_identifier(f"chunks_{bank_id}")
        self.vector_table = _make_sql_identifier(f"vec_chunks_{bank_id}")
        self.fts_table = _make_sql_identifier(f"fts_chunks_{bank_id}")
        self.kvstore = kvstore

    @classmethod

@ -148,14 +153,14 @@ class SQLiteVecIndex(EmbeddingIndex):
        try:
            # Create the table to store chunk metadata.
            cur.execute(f"""
                CREATE TABLE IF NOT EXISTS {self.metadata_table} (
                CREATE TABLE IF NOT EXISTS [{self.metadata_table}] (
                    id TEXT PRIMARY KEY,
                    chunk TEXT
                );
            """)
            # Create the virtual table for embeddings.
            cur.execute(f"""
                CREATE VIRTUAL TABLE IF NOT EXISTS {self.vector_table}
                CREATE VIRTUAL TABLE IF NOT EXISTS [{self.vector_table}]
                USING vec0(embedding FLOAT[{self.dimension}], id TEXT);
            """)
            connection.commit()

@ -163,7 +168,7 @@ class SQLiteVecIndex(EmbeddingIndex):
            # based on query. Implementation of the change on client side will allow passing the search_mode option
            # during initialization to make it easier to create the table that is required.
            cur.execute(f"""
                CREATE VIRTUAL TABLE IF NOT EXISTS {self.fts_table}
                CREATE VIRTUAL TABLE IF NOT EXISTS [{self.fts_table}]
                USING fts5(id, content);
            """)
            connection.commit()

@ -178,9 +183,9 @@ class SQLiteVecIndex(EmbeddingIndex):
        connection = _create_sqlite_connection(self.db_path)
        cur = connection.cursor()
        try:
            cur.execute(f"DROP TABLE IF EXISTS {self.metadata_table};")
            cur.execute(f"DROP TABLE IF EXISTS {self.vector_table};")
            cur.execute(f"DROP TABLE IF EXISTS {self.fts_table};")
            cur.execute(f"DROP TABLE IF EXISTS [{self.metadata_table}];")
            cur.execute(f"DROP TABLE IF EXISTS [{self.vector_table}];")
            cur.execute(f"DROP TABLE IF EXISTS [{self.fts_table}];")
            connection.commit()
        finally:
            cur.close()

@ -212,7 +217,7 @@ class SQLiteVecIndex(EmbeddingIndex):
                metadata_data = [(chunk.chunk_id, chunk.model_dump_json()) for chunk in batch_chunks]
                cur.executemany(
                    f"""
                    INSERT INTO {self.metadata_table} (id, chunk)
                    INSERT INTO [{self.metadata_table}] (id, chunk)
                    VALUES (?, ?)
                    ON CONFLICT(id) DO UPDATE SET chunk = excluded.chunk;
                    """,

@ -230,7 +235,7 @@ class SQLiteVecIndex(EmbeddingIndex):
                    for chunk, emb in zip(batch_chunks, batch_embeddings, strict=True)
                ]
                cur.executemany(
                    f"INSERT INTO {self.vector_table} (id, embedding) VALUES (?, ?);",
                    f"INSERT INTO [{self.vector_table}] (id, embedding) VALUES (?, ?);",
                    embedding_data,
                )

@ -238,13 +243,13 @@ class SQLiteVecIndex(EmbeddingIndex):
                fts_data = [(chunk.chunk_id, chunk.content) for chunk in batch_chunks]
                # DELETE existing entries with same IDs (FTS5 doesn't support ON CONFLICT)
                cur.executemany(
                    f"DELETE FROM {self.fts_table} WHERE id = ?;",
                    f"DELETE FROM [{self.fts_table}] WHERE id = ?;",
                    [(row[0],) for row in fts_data],
                )

                # INSERT new entries
                cur.executemany(
                    f"INSERT INTO {self.fts_table} (id, content) VALUES (?, ?);",
                    f"INSERT INTO [{self.fts_table}] (id, content) VALUES (?, ?);",
                    fts_data,
                )

@ -280,8 +285,8 @@ class SQLiteVecIndex(EmbeddingIndex):
            emb_blob = serialize_vector(emb_list)
            query_sql = f"""
                SELECT m.id, m.chunk, v.distance
                FROM {self.vector_table} AS v
                JOIN {self.metadata_table} AS m ON m.id = v.id
                FROM [{self.vector_table}] AS v
                JOIN [{self.metadata_table}] AS m ON m.id = v.id
                WHERE v.embedding MATCH ? AND k = ?
                ORDER BY v.distance;
            """

@ -322,9 +327,9 @@ class SQLiteVecIndex(EmbeddingIndex):
        cur = connection.cursor()
        try:
            query_sql = f"""
                SELECT DISTINCT m.id, m.chunk, bm25({self.fts_table}) AS score
                FROM {self.fts_table} AS f
                JOIN {self.metadata_table} AS m ON m.id = f.id
                SELECT DISTINCT m.id, m.chunk, bm25([{self.fts_table}]) AS score
                FROM [{self.fts_table}] AS f
                JOIN [{self.metadata_table}] AS m ON m.id = f.id
                WHERE f.content MATCH ?
                ORDER BY score ASC
                LIMIT ?;
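
The changes above defend the generated table names in two ways: `_make_sql_identifier()` maps anything outside `[a-zA-Z0-9_]` to an underscore, and every use of the name in SQL is additionally bracket-quoted. A small standalone sketch of the sanitizer's effect on an awkward or hostile `bank_id`, mirroring the one-line helper in this hunk:

    import re


    def _make_sql_identifier(name: str) -> str:
        return re.sub(r"[^a-zA-Z0-9_]", "_", name)


    print(_make_sql_identifier("chunks_my-bank.v2"))          # chunks_my_bank_v2
    print(_make_sql_identifier("chunks_x; DROP TABLE y;--"))  # chunks_x__DROP_TABLE_y___
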
@ -37,16 +37,6 @@ def available_providers() -> list[ProviderSpec]:
            config_class="llama_stack.providers.inline.inference.meta_reference.MetaReferenceInferenceConfig",
            description="Meta's reference implementation of inference with support for various model formats and optimization techniques.",
        ),
        InlineProviderSpec(
            api=Api.inference,
            provider_type="inline::vllm",
            pip_packages=[
                "vllm",
            ],
            module="llama_stack.providers.inline.inference.vllm",
            config_class="llama_stack.providers.inline.inference.vllm.VLLMConfig",
            description="vLLM inference provider for high-performance model serving with PagedAttention and continuous batching.",
        ),
        InlineProviderSpec(
            api=Api.inference,
            provider_type="inline::sentence-transformers",
@ -3,16 +3,17 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import logging

from llama_stack.providers.remote.inference.llama_openai_compat.config import (
    LlamaCompatConfig,
)
from llama_stack.providers.utils.inference.litellm_openai_mixin import (
    LiteLLMOpenAIMixin,
)
from llama_api_client import AsyncLlamaAPIClient, NotFoundError

from llama_stack.providers.remote.inference.llama_openai_compat.config import LlamaCompatConfig
from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin

from .models import MODEL_ENTRIES

logger = logging.getLogger(__name__)


class LlamaCompatInferenceAdapter(LiteLLMOpenAIMixin):
    _config: LlamaCompatConfig

@ -27,8 +28,32 @@ class LlamaCompatInferenceAdapter(LiteLLMOpenAIMixin):
        )
        self.config = config

    async def check_model_availability(self, model: str) -> bool:
        """
        Check if a specific model is available from Llama API.

        :param model: The model identifier to check.
        :return: True if the model is available dynamically, False otherwise.
        """
        try:
            llama_api_client = self._get_llama_api_client()
            retrieved_model = await llama_api_client.models.retrieve(model)
            logger.info(f"Model {retrieved_model.id} is available from Llama API")
            return True

        except NotFoundError:
            logger.error(f"Model {model} is not available from Llama API")
            return False

        except Exception as e:
            logger.error(f"Failed to check model availability from Llama API: {e}")
            return False

    async def initialize(self):
        await super().initialize()

    async def shutdown(self):
        await super().shutdown()

    def _get_llama_api_client(self) -> AsyncLlamaAPIClient:
        return AsyncLlamaAPIClient(api_key=self.get_api_key(), base_url=self.config.openai_compat_api_base)
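
A short usage sketch for the new availability probe, assuming an adapter instance has already been constructed and initialized elsewhere; the model id and helper name below are only examples:

    async def ensure_model(adapter: LlamaCompatInferenceAdapter, model_id: str) -> None:
        # Illustrative: probe the Llama API for a model before trying to serve it.
        if not await adapter.check_model_availability(model_id):
            raise RuntimeError(f"Model {model_id} is not served by the configured Llama API endpoint")
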
@ -7,10 +7,9 @@
|
|||
import logging
|
||||
import warnings
|
||||
from collections.abc import AsyncIterator
|
||||
from functools import lru_cache
|
||||
from typing import Any
|
||||
|
||||
from openai import APIConnectionError, AsyncOpenAI, BadRequestError
|
||||
from openai import APIConnectionError, AsyncOpenAI, BadRequestError, NotFoundError
|
||||
|
||||
from llama_stack.apis.common.content_types import (
|
||||
InterleavedContent,
|
||||
|
|
@ -41,11 +40,7 @@ from llama_stack.apis.inference import (
|
|||
ToolChoice,
|
||||
ToolConfig,
|
||||
)
|
||||
from llama_stack.apis.models import Model, ModelType
|
||||
from llama_stack.models.llama.datatypes import ToolDefinition, ToolPromptFormat
|
||||
from llama_stack.providers.utils.inference import (
|
||||
ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.model_registry import (
|
||||
ModelRegistryHelper,
|
||||
)
|
||||
|
|
@ -93,41 +88,37 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):

         self._config = config

-    @lru_cache  # noqa: B019
-    def _get_client(self, provider_model_id: str) -> AsyncOpenAI:
+    async def check_model_availability(self, model: str) -> bool:
         """
-        For hosted models, https://integrate.api.nvidia.com/v1 is the primary base_url. However,
-        some models are hosted on different URLs. This function returns the appropriate client
-        for the given provider_model_id.
+        Check if a specific model is available.

-        This relies on lru_cache and self._default_client to avoid creating a new client for each request
-        or for each model that is hosted on https://integrate.api.nvidia.com/v1.
+        :param model: The model identifier to check.
+        :return: True if the model is available dynamically, False otherwise.
+        """
+        try:
+            await self._client.models.retrieve(model)
+            return True
+        except NotFoundError:
+            logger.error(f"Model {model} is not available")
+        except Exception as e:
+            logger.error(f"Failed to check model availability: {e}")
+        return False

+    @property
+    def _client(self) -> AsyncOpenAI:
+        """
+        Returns an OpenAI client for the configured NVIDIA API endpoint.

-        :param provider_model_id: The provider model ID
         :return: An OpenAI client
         """
-
-        @lru_cache  # noqa: B019
-        def _get_client_for_base_url(base_url: str) -> AsyncOpenAI:
-            """
-            Maintain a single OpenAI client per base_url.
-            """
-            return AsyncOpenAI(
-                base_url=base_url,
-                api_key=(self._config.api_key.get_secret_value() if self._config.api_key else "NO KEY"),
-                timeout=self._config.timeout,
-            )
-
-        special_model_urls = {
-            "meta/llama-3.2-11b-vision-instruct": "https://ai.api.nvidia.com/v1/gr/meta/llama-3.2-11b-vision-instruct",
-            "meta/llama-3.2-90b-vision-instruct": "https://ai.api.nvidia.com/v1/gr/meta/llama-3.2-90b-vision-instruct",
-        }
-
         base_url = f"{self._config.url}/v1" if self._config.append_api_version else self._config.url

-        if _is_nvidia_hosted(self._config) and provider_model_id in special_model_urls:
-            base_url = special_model_urls[provider_model_id]
-        return _get_client_for_base_url(base_url)
+        return AsyncOpenAI(
+            base_url=base_url,
+            api_key=(self._config.api_key.get_secret_value() if self._config.api_key else "NO KEY"),
+            timeout=self._config.timeout,
+        )

     async def _get_provider_model_id(self, model_id: str) -> str:
         if not self.model_store:
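For reference, the availability check introduced here boils down to a `models.retrieve` call against the OpenAI-compatible endpoint. A minimal standalone sketch follows; the endpoint URL comes from the removed docstring above, while the model id and the `NVIDIA_API_KEY` environment variable are illustrative assumptions:

```python
import asyncio
import os

from openai import AsyncOpenAI, NotFoundError


async def is_model_available(model: str) -> bool:
    # OpenAI-compatible client pointed at the hosted NVIDIA endpoint.
    client = AsyncOpenAI(
        base_url="https://integrate.api.nvidia.com/v1",
        api_key=os.environ.get("NVIDIA_API_KEY", "NO KEY"),
    )
    try:
        await client.models.retrieve(model)
        return True
    except NotFoundError:
        return False


# Hypothetical model id, used only to exercise the check.
print(asyncio.run(is_model_available("meta/llama-3.1-8b-instruct")))
```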
|
@ -169,7 +160,7 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
|
|||
)
|
||||
|
||||
try:
|
||||
response = await self._get_client(provider_model_id).completions.create(**request)
|
||||
response = await self._client.completions.create(**request)
|
||||
except APIConnectionError as e:
|
||||
raise ConnectionError(f"Failed to connect to NVIDIA NIM at {self._config.url}: {e}") from e
|
||||
|
||||
|
|
@ -222,7 +213,7 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
|
|||
extra_body["input_type"] = task_type_options[task_type]
|
||||
|
||||
try:
|
||||
response = await self._get_client(provider_model_id).embeddings.create(
|
||||
response = await self._client.embeddings.create(
|
||||
model=provider_model_id,
|
||||
input=input,
|
||||
extra_body=extra_body,
|
||||
|
|
@ -283,7 +274,7 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
|
|||
)
|
||||
|
||||
try:
|
||||
response = await self._get_client(provider_model_id).chat.completions.create(**request)
|
||||
response = await self._client.chat.completions.create(**request)
|
||||
except APIConnectionError as e:
|
||||
raise ConnectionError(f"Failed to connect to NVIDIA NIM at {self._config.url}: {e}") from e
|
||||
|
||||
|
|
@ -339,7 +330,7 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
|
|||
)
|
||||
|
||||
try:
|
||||
return await self._get_client(provider_model_id).completions.create(**params)
|
||||
return await self._client.completions.create(**params)
|
||||
except APIConnectionError as e:
|
||||
raise ConnectionError(f"Failed to connect to NVIDIA NIM at {self._config.url}: {e}") from e
|
||||
|
||||
|
|
@ -398,47 +389,6 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
         )

         try:
-            return await self._get_client(provider_model_id).chat.completions.create(**params)
+            return await self._client.chat.completions.create(**params)
         except APIConnectionError as e:
             raise ConnectionError(f"Failed to connect to NVIDIA NIM at {self._config.url}: {e}") from e

-    async def register_model(self, model: Model) -> Model:
-        """
-        Allow non-llama model registration.
-
-        Non-llama model registration: API Catalogue models, post-training models, etc.
-            client = LlamaStackAsLibraryClient("nvidia")
-            client.models.register(
-                model_id="mistralai/mixtral-8x7b-instruct-v0.1",
-                model_type=ModelType.llm,
-                provider_id="nvidia",
-                provider_model_id="mistralai/mixtral-8x7b-instruct-v0.1"
-            )
-
-        NOTE: Only supports models endpoints compatible with AsyncOpenAI base_url format.
-        """
-        if model.model_type == ModelType.embedding:
-            # embedding models are always registered by their provider model id and does not need to be mapped to a llama model
-            provider_resource_id = model.provider_resource_id
-        else:
-            provider_resource_id = self.get_provider_model_id(model.provider_resource_id)
-
-        if provider_resource_id:
-            model.provider_resource_id = provider_resource_id
-        else:
-            llama_model = model.metadata.get("llama_model")
-            existing_llama_model = self.get_llama_model(model.provider_resource_id)
-            if existing_llama_model:
-                if existing_llama_model != llama_model:
-                    raise ValueError(
-                        f"Provider model id '{model.provider_resource_id}' is already registered to a different llama model: '{existing_llama_model}'"
-                    )
-            else:
-                # not llama model
-                if llama_model in ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR:
-                    self.provider_id_to_llama_model_map[model.provider_resource_id] = (
-                        ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR[llama_model]
-                    )
-                else:
-                    self.alias_to_provider_id_map[model.provider_model_id] = model.provider_model_id
-        return model

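The example embedded in the removed docstring, expanded into a runnable sketch. This assumes the `nvidia` distribution's dependencies are installed, that `NVIDIA_API_KEY` is set, and that `LlamaStackAsLibraryClient` exposes an `initialize()` step as in the library-client usage shown elsewhere in the tests:

```python
from llama_stack.apis.models import ModelType
from llama_stack.distribution.library_client import LlamaStackAsLibraryClient

client = LlamaStackAsLibraryClient("nvidia")
client.initialize()

# Register an API Catalog model that is not a Llama model; per the removed branch above,
# it is kept under its provider model id since no HF-repo -> llama-model mapping exists.
client.models.register(
    model_id="mistralai/mixtral-8x7b-instruct-v0.1",
    model_type=ModelType.llm,
    provider_id="nvidia",
    provider_model_id="mistralai/mixtral-8x7b-instruct-v0.1",
)
```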
@ -6,13 +6,15 @@

 from typing import Any

-from pydantic import BaseModel
+from pydantic import BaseModel, Field

 DEFAULT_OLLAMA_URL = "http://localhost:11434"


 class OllamaImplConfig(BaseModel):
     url: str = DEFAULT_OLLAMA_URL
+    refresh_models: bool = Field(default=False, description="refresh and re-register models periodically")
+    refresh_models_interval: int = Field(default=300, description="interval in seconds to refresh models")

     @classmethod
     def sample_run_config(cls, url: str = "${env.OLLAMA_URL:=http://localhost:11434}", **kwargs) -> dict[str, Any]:
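A self-contained sketch of how the two new fields behave; the class is re-declared locally so the snippet runs without the llama_stack package installed:

```python
from pydantic import BaseModel, Field

DEFAULT_OLLAMA_URL = "http://localhost:11434"


class OllamaImplConfig(BaseModel):
    url: str = DEFAULT_OLLAMA_URL
    refresh_models: bool = Field(default=False, description="refresh and re-register models periodically")
    refresh_models_interval: int = Field(default=300, description="interval in seconds to refresh models")


# Defaults keep refreshing off; enabling it with a shorter interval is a one-liner.
cfg = OllamaImplConfig(refresh_models=True, refresh_models_interval=60)
print(cfg.model_dump())
```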
@ -5,6 +5,7 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import uuid
|
||||
from collections.abc import AsyncGenerator, AsyncIterator
|
||||
|
|
@ -91,23 +92,88 @@ class OllamaInferenceAdapter(
|
|||
InferenceProvider,
|
||||
ModelsProtocolPrivate,
|
||||
):
|
||||
# automatically set by the resolver when instantiating the provider
|
||||
__provider_id__: str
|
||||
|
||||
def __init__(self, config: OllamaImplConfig) -> None:
|
||||
self.register_helper = ModelRegistryHelper(MODEL_ENTRIES)
|
||||
self.url = config.url
|
||||
self.config = config
|
||||
self._client = None
|
||||
self._openai_client = None
|
||||
|
||||
@property
|
||||
def client(self) -> AsyncClient:
|
||||
return AsyncClient(host=self.url)
|
||||
if self._client is None:
|
||||
self._client = AsyncClient(host=self.config.url)
|
||||
return self._client
|
||||
|
||||
@property
|
||||
def openai_client(self) -> AsyncOpenAI:
|
||||
return AsyncOpenAI(base_url=f"{self.url}/v1", api_key="ollama")
|
||||
if self._openai_client is None:
|
||||
self._openai_client = AsyncOpenAI(base_url=f"{self.config.url}/v1", api_key="ollama")
|
||||
return self._openai_client
|
||||
|
||||
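As a design note, the hand-rolled `_client`/`_openai_client` caching above could also be written with `functools.cached_property`; this is a sketch of that alternative, not what the adapter actually does:

```python
from functools import cached_property


class FakeClient:
    def __init__(self, host: str) -> None:
        self.host = host


class Adapter:
    def __init__(self, url: str) -> None:
        self.url = url

    @cached_property
    def client(self) -> FakeClient:
        # Built once on first access, then reused for the adapter's lifetime.
        return FakeClient(host=self.url)


a = Adapter("http://localhost:11434")
assert a.client is a.client  # same instance on every access
```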
async def initialize(self) -> None:
|
||||
logger.debug(f"checking connectivity to Ollama at `{self.url}`...")
|
||||
logger.info(f"checking connectivity to Ollama at `{self.config.url}`...")
|
||||
health_response = await self.health()
|
||||
if health_response["status"] == HealthStatus.ERROR:
|
||||
raise RuntimeError("Ollama Server is not running, start it using `ollama serve` in a separate terminal")
|
||||
logger.warning(
|
||||
"Ollama Server is not running, make sure to start it using `ollama serve` in a separate terminal"
|
||||
)
|
||||
|
||||
if self.config.refresh_models:
|
||||
logger.debug("ollama starting background model refresh task")
|
||||
self._refresh_task = asyncio.create_task(self._refresh_models())
|
||||
|
||||
def cb(task):
|
||||
if task.cancelled():
|
||||
import traceback
|
||||
|
||||
logger.error(f"ollama background refresh task canceled:\n{''.join(traceback.format_stack())}")
|
||||
elif task.exception():
|
||||
logger.error(f"ollama background refresh task died: {task.exception()}")
|
||||
else:
|
||||
logger.error("ollama background refresh task completed unexpectedly")
|
||||
|
||||
self._refresh_task.add_done_callback(cb)
|
||||
|
||||
async def _refresh_models(self) -> None:
|
||||
# Wait for model store to be available (with timeout)
|
||||
waited_time = 0
|
||||
while not self.model_store and waited_time < 60:
|
||||
await asyncio.sleep(1)
|
||||
waited_time += 1
|
||||
|
||||
if not self.model_store:
|
||||
raise ValueError("Model store not set after waiting 60 seconds")
|
||||
|
||||
provider_id = self.__provider_id__
|
||||
while True:
|
||||
try:
|
||||
response = await self.client.list()
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to list models: {str(e)}")
|
||||
await asyncio.sleep(self.config.refresh_models_interval)
|
||||
continue
|
||||
|
||||
models = []
|
||||
for m in response.models:
|
||||
model_type = ModelType.embedding if m.details.family in ["bert"] else ModelType.llm
|
||||
if model_type == ModelType.embedding:
|
||||
continue
|
||||
models.append(
|
||||
Model(
|
||||
identifier=m.model,
|
||||
provider_resource_id=m.model,
|
||||
provider_id=provider_id,
|
||||
metadata={},
|
||||
model_type=model_type,
|
||||
)
|
||||
)
|
||||
await self.model_store.update_registered_llm_models(provider_id, models)
|
||||
logger.debug(f"ollama refreshed model list ({len(models)} models)")
|
||||
|
||||
await asyncio.sleep(self.config.refresh_models_interval)
|
||||
|
||||
async def health(self) -> HealthResponse:
|
||||
"""
|
||||
|
|
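The refresh-loop plus done-callback pattern used in this hunk, reduced to a standalone asyncio sketch; the print statements stand in for listing and re-registering models:

```python
import asyncio


async def refresh_models_forever(interval: float) -> None:
    while True:
        print("refreshing model list...")  # stand-in for client.list() + re-registration
        await asyncio.sleep(interval)


async def main() -> None:
    task = asyncio.create_task(refresh_models_forever(interval=1.0))

    def cb(task: asyncio.Task) -> None:
        # Mirrors the adapter's callback: cancellation is the expected shutdown path,
        # anything else means the refresh loop died.
        if task.cancelled():
            print("refresh task cancelled (expected on shutdown)")
        elif task.exception():
            print(f"refresh task died: {task.exception()}")
        else:
            print("refresh task completed unexpectedly")

    task.add_done_callback(cb)

    await asyncio.sleep(3)    # simulated server lifetime
    task.cancel()             # shutdown path
    await asyncio.sleep(0.1)  # let the cancellation and callback run


asyncio.run(main())
```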
@ -124,7 +190,12 @@ class OllamaInferenceAdapter(
|
|||
return HealthResponse(status=HealthStatus.ERROR, message=f"Health check failed: {str(e)}")
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
pass
|
||||
if hasattr(self, "_refresh_task") and not self._refresh_task.done():
|
||||
logger.debug("ollama cancelling background refresh task")
|
||||
self._refresh_task.cancel()
|
||||
|
||||
self._client = None
|
||||
self._openai_client = None
|
||||
|
||||
async def unregister_model(self, model_id: str) -> None:
|
||||
pass
|
||||
|
|
@ -354,8 +425,6 @@ class OllamaInferenceAdapter(
|
|||
raise ValueError("Model provider_resource_id cannot be None")
|
||||
|
||||
if model.model_type == ModelType.embedding:
|
||||
logger.info(f"Pulling embedding model `{model.provider_resource_id}` if necessary...")
|
||||
# TODO: you should pull here only if the model is not found in a list
|
||||
response = await self.client.list()
|
||||
if model.provider_resource_id not in [m.model for m in response.models]:
|
||||
await self.client.pull(model.provider_resource_id)
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@ import logging
 from collections.abc import AsyncIterator
 from typing import Any

-from openai import AsyncOpenAI
+from openai import AsyncOpenAI, NotFoundError

 from llama_stack.apis.inference import (
     OpenAIChatCompletion,

@ -60,6 +60,27 @@ class OpenAIInferenceAdapter(LiteLLMOpenAIMixin):
         # litellm specific model names, an abstraction leak.
         self.is_openai_compat = True

+    async def check_model_availability(self, model: str) -> bool:
+        """
+        Check if a specific model is available from OpenAI.
+
+        :param model: The model identifier to check.
+        :return: True if the model is available dynamically, False otherwise.
+        """
+        try:
+            openai_client = self._get_openai_client()
+            retrieved_model = await openai_client.models.retrieve(model)
+            logger.info(f"Model {retrieved_model.id} is available from OpenAI")
+            return True
+
+        except NotFoundError:
+            logger.error(f"Model {model} is not available from OpenAI")
+            return False
+
+        except Exception as e:
+            logger.error(f"Failed to check model availability from OpenAI: {e}")
+            return False
+
     async def initialize(self) -> None:
         await super().initialize()
@ -29,6 +29,14 @@ class VLLMInferenceAdapterConfig(BaseModel):
         default=True,
         description="Whether to verify TLS certificates. Can be a boolean or a path to a CA certificate file.",
     )
+    refresh_models: bool = Field(
+        default=False,
+        description="Whether to refresh models periodically",
+    )
+    refresh_models_interval: int = Field(
+        default=300,
+        description="Interval in seconds to refresh models",
+    )

     @field_validator("tls_verify")
     @classmethod

@ -46,7 +54,7 @@ class VLLMInferenceAdapterConfig(BaseModel):
     @classmethod
     def sample_run_config(
         cls,
-        url: str = "${env.VLLM_URL}",
+        url: str = "${env.VLLM_URL:=}",
         **kwargs,
     ):
         return {
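The new `${env.VLLM_URL:=}` default means the URL may resolve to an empty string; a tiny sketch of that substitution and of the "dormant provider" behaviour the adapter adopts later in this diff (the messages are illustrative):

```python
import os

# "${env.VLLM_URL:=}" -> use the env var if set, otherwise an empty string.
url = os.environ.get("VLLM_URL", "")

if not url:
    # Mirrors the adapter's initialize(): stay dormant instead of failing hard.
    print("VLLM_URL not set; vLLM provider stays dormant")
else:
    print(f"vLLM provider will connect to {url}")
```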
|
@ -3,8 +3,8 @@
|
|||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import AsyncGenerator, AsyncIterator
|
||||
from typing import Any
|
||||
|
||||
|
|
@ -38,6 +38,7 @@ from llama_stack.apis.inference import (
|
|||
JsonSchemaResponseFormat,
|
||||
LogProbConfig,
|
||||
Message,
|
||||
ModelStore,
|
||||
OpenAIChatCompletion,
|
||||
OpenAICompletion,
|
||||
OpenAIEmbeddingData,
|
||||
|
|
@ -54,6 +55,7 @@ from llama_stack.apis.inference import (
|
|||
ToolPromptFormat,
|
||||
)
|
||||
from llama_stack.apis.models import Model, ModelType
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.models.llama.datatypes import BuiltinTool, StopReason, ToolCall
|
||||
from llama_stack.models.llama.sku_list import all_registered_models
|
||||
from llama_stack.providers.datatypes import (
|
||||
|
|
@ -84,7 +86,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
|
|||
|
||||
from .config import VLLMInferenceAdapterConfig
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
log = get_logger(name=__name__, category="inference")
|
||||
|
||||
|
||||
def build_hf_repo_model_entries():
|
||||
|
|
@ -288,16 +290,76 @@ async def _process_vllm_chat_completion_stream_response(
|
|||
|
||||
|
||||
class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
|
||||
# automatically set by the resolver when instantiating the provider
|
||||
__provider_id__: str
|
||||
model_store: ModelStore | None = None
|
||||
_refresh_task: asyncio.Task | None = None
|
||||
|
||||
def __init__(self, config: VLLMInferenceAdapterConfig) -> None:
|
||||
self.register_helper = ModelRegistryHelper(build_hf_repo_model_entries())
|
||||
self.config = config
|
||||
self.client = None
|
||||
|
||||
async def initialize(self) -> None:
|
||||
pass
|
||||
if not self.config.url:
|
||||
# intentionally don't raise an error here, we want to allow the provider to be "dormant"
|
||||
# or available in distributions like "starter" without causing a ruckus
|
||||
return
|
||||
|
||||
if self.config.refresh_models:
|
||||
self._refresh_task = asyncio.create_task(self._refresh_models())
|
||||
|
||||
def cb(task):
|
||||
import traceback
|
||||
|
||||
if task.cancelled():
|
||||
log.error(f"vLLM background refresh task canceled:\n{''.join(traceback.format_stack())}")
|
||||
elif task.exception():
|
||||
# print the stack trace for the exception
|
||||
exc = task.exception()
|
||||
log.error(f"vLLM background refresh task died: {exc}")
|
||||
traceback.print_exception(exc)
|
||||
else:
|
||||
log.error("vLLM background refresh task completed unexpectedly")
|
||||
|
||||
self._refresh_task.add_done_callback(cb)
|
||||
|
||||
async def _refresh_models(self) -> None:
|
||||
provider_id = self.__provider_id__
|
||||
waited_time = 0
|
||||
while not self.model_store and waited_time < 60:
|
||||
await asyncio.sleep(1)
|
||||
waited_time += 1
|
||||
|
||||
if not self.model_store:
|
||||
raise ValueError("Model store not set after waiting 60 seconds")
|
||||
|
||||
self._lazy_initialize_client()
|
||||
assert self.client is not None # mypy
|
||||
while True:
|
||||
try:
|
||||
models = []
|
||||
async for m in self.client.models.list():
|
||||
model_type = ModelType.llm # unclear how to determine embedding vs. llm models
|
||||
models.append(
|
||||
Model(
|
||||
identifier=m.id,
|
||||
provider_resource_id=m.id,
|
||||
provider_id=provider_id,
|
||||
metadata={},
|
||||
model_type=model_type,
|
||||
)
|
||||
)
|
||||
await self.model_store.update_registered_llm_models(provider_id, models)
|
||||
log.debug(f"vLLM refreshed model list ({len(models)} models)")
|
||||
except Exception as e:
|
||||
log.error(f"vLLM background refresh task failed: {e}")
|
||||
await asyncio.sleep(self.config.refresh_models_interval)
|
||||
|
||||
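The model-listing call in the loop above, isolated into a runnable sketch; the base URL and API key are placeholders for a locally running vLLM server:

```python
import asyncio

from openai import AsyncOpenAI


async def list_vllm_models() -> list[str]:
    # vLLM exposes an OpenAI-compatible /v1 API; the token is a placeholder.
    client = AsyncOpenAI(base_url="http://localhost:8000/v1", api_key="fake")
    return [m.id async for m in client.models.list()]


print(asyncio.run(list_vllm_models()))
```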
async def shutdown(self) -> None:
|
||||
pass
|
||||
if self._refresh_task:
|
||||
self._refresh_task.cancel()
|
||||
self._refresh_task = None
|
||||
|
||||
async def unregister_model(self, model_id: str) -> None:
|
||||
pass
|
||||
|
|
@ -312,6 +374,9 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
|
|||
HealthResponse: A dictionary containing the health status.
|
||||
"""
|
||||
try:
|
||||
if not self.config.url:
|
||||
return HealthResponse(status=HealthStatus.ERROR, message="vLLM URL is not set")
|
||||
|
||||
client = self._create_client() if self.client is None else self.client
|
||||
_ = [m async for m in client.models.list()] # Ensure the client is initialized
|
||||
return HealthResponse(status=HealthStatus.OK)
|
||||
|
|
@ -327,6 +392,11 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
|
|||
if self.client is not None:
|
||||
return
|
||||
|
||||
if not self.config.url:
|
||||
raise ValueError(
|
||||
"You must provide a vLLM URL in the run.yaml file (or set the VLLM_URL environment variable)"
|
||||
)
|
||||
|
||||
log.info(f"Initializing vLLM client with base_url={self.config.url}")
|
||||
self.client = self._create_client()
|
||||
|
||||
|
|
|
|||
|
|
@ -13,7 +13,6 @@ from llama_stack.apis.common.content_types import (
|
|||
InterleavedContent,
|
||||
InterleavedContentItem,
|
||||
)
|
||||
from llama_stack.apis.common.errors import UnsupportedModelError
|
||||
from llama_stack.apis.inference import (
|
||||
ChatCompletionRequest,
|
||||
ChatCompletionResponse,
|
||||
|
|
@ -39,7 +38,6 @@ from llama_stack.apis.inference import (
|
|||
ToolDefinition,
|
||||
ToolPromptFormat,
|
||||
)
|
||||
from llama_stack.apis.models import Model
|
||||
from llama_stack.distribution.request_headers import NeedsRequestProviderData
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
|
||||
|
|
@ -90,12 +88,6 @@ class LiteLLMOpenAIMixin(
|
|||
async def shutdown(self):
|
||||
pass
|
||||
|
||||
async def register_model(self, model: Model) -> Model:
|
||||
model_id = self.get_provider_model_id(model.provider_resource_id)
|
||||
if model_id is None:
|
||||
raise UnsupportedModelError(model.provider_resource_id, self.alias_to_provider_id_map.keys())
|
||||
return model
|
||||
|
||||
def get_litellm_model_name(self, model_id: str) -> str:
|
||||
# users may be using openai/ prefix in their model names. the openai/models.py did this by default.
|
||||
# model_id.startswith("openai/") is for backwards compatibility.
|
||||
|
|
|
|||
|
|
@ -4,4 +4,4 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from .vllm import get_distribution_template  # noqa: F401
+from .ci_tests import get_distribution_template  # noqa: F401

llama_stack/templates/ci-tests/build.yaml (new file, 65 lines)

@ -0,0 +1,65 @@
|
|||
version: 2
|
||||
distribution_spec:
|
||||
description: CI tests for Llama Stack
|
||||
providers:
|
||||
inference:
|
||||
- remote::cerebras
|
||||
- remote::ollama
|
||||
- remote::vllm
|
||||
- remote::tgi
|
||||
- remote::hf::serverless
|
||||
- remote::hf::endpoint
|
||||
- remote::fireworks
|
||||
- remote::together
|
||||
- remote::bedrock
|
||||
- remote::databricks
|
||||
- remote::nvidia
|
||||
- remote::runpod
|
||||
- remote::openai
|
||||
- remote::anthropic
|
||||
- remote::gemini
|
||||
- remote::groq
|
||||
- remote::fireworks-openai-compat
|
||||
- remote::llama-openai-compat
|
||||
- remote::together-openai-compat
|
||||
- remote::groq-openai-compat
|
||||
- remote::sambanova-openai-compat
|
||||
- remote::cerebras-openai-compat
|
||||
- remote::sambanova
|
||||
- remote::passthrough
|
||||
- inline::sentence-transformers
|
||||
vector_io:
|
||||
- inline::faiss
|
||||
- inline::sqlite-vec
|
||||
- inline::milvus
|
||||
- remote::chromadb
|
||||
- remote::pgvector
|
||||
files:
|
||||
- inline::localfs
|
||||
safety:
|
||||
- inline::llama-guard
|
||||
agents:
|
||||
- inline::meta-reference
|
||||
telemetry:
|
||||
- inline::meta-reference
|
||||
post_training:
|
||||
- inline::huggingface
|
||||
eval:
|
||||
- inline::meta-reference
|
||||
datasetio:
|
||||
- remote::huggingface
|
||||
- inline::localfs
|
||||
scoring:
|
||||
- inline::basic
|
||||
- inline::llm-as-judge
|
||||
- inline::braintrust
|
||||
tool_runtime:
|
||||
- remote::brave-search
|
||||
- remote::tavily-search
|
||||
- inline::rag-runtime
|
||||
- remote::model-context-protocol
|
||||
image_type: conda
|
||||
additional_pip_packages:
|
||||
- aiosqlite
|
||||
- asyncpg
|
||||
- sqlalchemy[asyncio]
|
||||
llama_stack/templates/ci-tests/ci_tests.py (new file, 19 lines)

@ -0,0 +1,19 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+
+from llama_stack.templates.template import DistributionTemplate
+
+from ..starter.starter import get_distribution_template as get_starter_distribution_template
+
+
+def get_distribution_template() -> DistributionTemplate:
+    template = get_starter_distribution_template()
+    name = "ci-tests"
+    template.name = name
+    template.description = "CI tests for Llama Stack"
+
+    return template

llama_stack/templates/ci-tests/run.yaml (new file, 1189 lines)
File diff suppressed because it is too large.
|
|
@ -26,7 +26,7 @@ providers:
|
|||
- provider_id: ${env.ENABLE_VLLM:=__disabled__}
|
||||
provider_type: remote::vllm
|
||||
config:
|
||||
url: ${env.VLLM_URL}
|
||||
url: ${env.VLLM_URL:=}
|
||||
max_tokens: ${env.VLLM_MAX_TOKENS:=4096}
|
||||
api_token: ${env.VLLM_API_TOKEN:=fake}
|
||||
tls_verify: ${env.VLLM_TLS_VERIFY:=true}
|
||||
|
|
@ -262,6 +262,11 @@ inference_store:
|
|||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/inference_store.db
|
||||
models:
|
||||
- metadata:
|
||||
embedding_dimension: 384
|
||||
model_id: all-MiniLM-L6-v2
|
||||
provider_id: ${env.ENABLE_SENTENCE_TRANSFORMERS:=sentence-transformers}
|
||||
model_type: embedding
|
||||
- metadata: {}
|
||||
model_id: ${env.ENABLE_CEREBRAS:=__disabled__}/llama3.1-8b
|
||||
provider_id: ${env.ENABLE_CEREBRAS:=__disabled__}
|
||||
|
|
@ -1168,11 +1173,6 @@ models:
|
|||
provider_id: ${env.ENABLE_SAMBANOVA:=__disabled__}
|
||||
provider_model_id: sambanova/Meta-Llama-Guard-3-8B
|
||||
model_type: llm
|
||||
- metadata:
|
||||
embedding_dimension: 384
|
||||
model_id: all-MiniLM-L6-v2
|
||||
provider_id: ${env.ENABLE_SENTENCE_TRANSFORMERS:=sentence-transformers}
|
||||
model_type: embedding
|
||||
shields:
|
||||
- shield_id: ${env.SAFETY_MODEL:=__disabled__}
|
||||
provider_shield_id: ${env.ENABLE_OLLAMA:=__disabled__}/${env.SAFETY_MODEL:=__disabled__}
|
||||
|
|
|
|||
|
|
@ -323,7 +323,7 @@ def get_distribution_template() -> DistributionTemplate:
|
|||
"files": [files_provider],
|
||||
"post_training": [post_training_provider],
|
||||
},
|
||||
default_models=default_models + [embedding_model],
|
||||
default_models=[embedding_model] + default_models,
|
||||
default_tool_groups=default_tool_groups,
|
||||
# TODO: add a way to enable/disable shields on the fly
|
||||
default_shields=shields,
|
||||
|
|
|
|||
|
|
@ -1,35 +0,0 @@
|
|||
version: 2
|
||||
distribution_spec:
|
||||
description: Use a built-in vLLM engine for running LLM inference
|
||||
providers:
|
||||
inference:
|
||||
- inline::vllm
|
||||
- inline::sentence-transformers
|
||||
vector_io:
|
||||
- inline::faiss
|
||||
- remote::chromadb
|
||||
- remote::pgvector
|
||||
safety:
|
||||
- inline::llama-guard
|
||||
agents:
|
||||
- inline::meta-reference
|
||||
telemetry:
|
||||
- inline::meta-reference
|
||||
eval:
|
||||
- inline::meta-reference
|
||||
datasetio:
|
||||
- remote::huggingface
|
||||
- inline::localfs
|
||||
scoring:
|
||||
- inline::basic
|
||||
- inline::llm-as-judge
|
||||
- inline::braintrust
|
||||
tool_runtime:
|
||||
- remote::brave-search
|
||||
- remote::tavily-search
|
||||
- inline::rag-runtime
|
||||
- remote::model-context-protocol
|
||||
image_type: conda
|
||||
additional_pip_packages:
|
||||
- aiosqlite
|
||||
- sqlalchemy[asyncio]
|
||||
|
|
@ -1,132 +0,0 @@
|
|||
version: 2
|
||||
image_name: vllm-gpu
|
||||
apis:
|
||||
- agents
|
||||
- datasetio
|
||||
- eval
|
||||
- inference
|
||||
- safety
|
||||
- scoring
|
||||
- telemetry
|
||||
- tool_runtime
|
||||
- vector_io
|
||||
providers:
|
||||
inference:
|
||||
- provider_id: vllm
|
||||
provider_type: inline::vllm
|
||||
config:
|
||||
tensor_parallel_size: ${env.TENSOR_PARALLEL_SIZE:=1}
|
||||
max_tokens: ${env.MAX_TOKENS:=4096}
|
||||
max_model_len: ${env.MAX_MODEL_LEN:=4096}
|
||||
max_num_seqs: ${env.MAX_NUM_SEQS:=4}
|
||||
enforce_eager: ${env.ENFORCE_EAGER:=False}
|
||||
gpu_memory_utilization: ${env.GPU_MEMORY_UTILIZATION:=0.3}
|
||||
- provider_id: sentence-transformers
|
||||
provider_type: inline::sentence-transformers
|
||||
config: {}
|
||||
vector_io:
|
||||
- provider_id: faiss
|
||||
provider_type: inline::faiss
|
||||
config:
|
||||
kvstore:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/vllm-gpu}/faiss_store.db
|
||||
safety:
|
||||
- provider_id: llama-guard
|
||||
provider_type: inline::llama-guard
|
||||
config:
|
||||
excluded_categories: []
|
||||
agents:
|
||||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
config:
|
||||
persistence_store:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/vllm-gpu}/agents_store.db
|
||||
responses_store:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/vllm-gpu}/responses_store.db
|
||||
telemetry:
|
||||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
config:
|
||||
service_name: "${env.OTEL_SERVICE_NAME:=\u200B}"
|
||||
sinks: ${env.TELEMETRY_SINKS:=console,sqlite}
|
||||
sqlite_db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/vllm-gpu}/trace_store.db
|
||||
otel_exporter_otlp_endpoint: ${env.OTEL_EXPORTER_OTLP_ENDPOINT:=}
|
||||
eval:
|
||||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
config:
|
||||
kvstore:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/vllm-gpu}/meta_reference_eval.db
|
||||
datasetio:
|
||||
- provider_id: huggingface
|
||||
provider_type: remote::huggingface
|
||||
config:
|
||||
kvstore:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/vllm-gpu}/huggingface_datasetio.db
|
||||
- provider_id: localfs
|
||||
provider_type: inline::localfs
|
||||
config:
|
||||
kvstore:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/vllm-gpu}/localfs_datasetio.db
|
||||
scoring:
|
||||
- provider_id: basic
|
||||
provider_type: inline::basic
|
||||
config: {}
|
||||
- provider_id: llm-as-judge
|
||||
provider_type: inline::llm-as-judge
|
||||
config: {}
|
||||
- provider_id: braintrust
|
||||
provider_type: inline::braintrust
|
||||
config:
|
||||
openai_api_key: ${env.OPENAI_API_KEY:=}
|
||||
tool_runtime:
|
||||
- provider_id: brave-search
|
||||
provider_type: remote::brave-search
|
||||
config:
|
||||
api_key: ${env.BRAVE_SEARCH_API_KEY:=}
|
||||
max_results: 3
|
||||
- provider_id: tavily-search
|
||||
provider_type: remote::tavily-search
|
||||
config:
|
||||
api_key: ${env.TAVILY_SEARCH_API_KEY:=}
|
||||
max_results: 3
|
||||
- provider_id: rag-runtime
|
||||
provider_type: inline::rag-runtime
|
||||
config: {}
|
||||
- provider_id: model-context-protocol
|
||||
provider_type: remote::model-context-protocol
|
||||
config: {}
|
||||
metadata_store:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/vllm-gpu}/registry.db
|
||||
inference_store:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/vllm-gpu}/inference_store.db
|
||||
models:
|
||||
- metadata: {}
|
||||
model_id: ${env.INFERENCE_MODEL}
|
||||
provider_id: vllm
|
||||
model_type: llm
|
||||
- metadata:
|
||||
embedding_dimension: 384
|
||||
model_id: all-MiniLM-L6-v2
|
||||
provider_id: sentence-transformers
|
||||
model_type: embedding
|
||||
shields: []
|
||||
vector_dbs: []
|
||||
datasets: []
|
||||
scoring_fns: []
|
||||
benchmarks: []
|
||||
tool_groups:
|
||||
- toolgroup_id: builtin::websearch
|
||||
provider_id: tavily-search
|
||||
- toolgroup_id: builtin::rag
|
||||
provider_id: rag-runtime
|
||||
server:
|
||||
port: 8321
|
||||
|
|
@ -1,122 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from llama_stack.apis.models import ModelType
|
||||
from llama_stack.distribution.datatypes import ModelInput, Provider
|
||||
from llama_stack.providers.inline.inference.sentence_transformers import (
|
||||
SentenceTransformersInferenceConfig,
|
||||
)
|
||||
from llama_stack.providers.inline.inference.vllm import VLLMConfig
|
||||
from llama_stack.providers.inline.vector_io.faiss.config import FaissVectorIOConfig
|
||||
from llama_stack.templates.template import (
|
||||
DistributionTemplate,
|
||||
RunConfigSettings,
|
||||
ToolGroupInput,
|
||||
)
|
||||
|
||||
|
||||
def get_distribution_template() -> DistributionTemplate:
|
||||
providers = {
|
||||
"inference": ["inline::vllm", "inline::sentence-transformers"],
|
||||
"vector_io": ["inline::faiss", "remote::chromadb", "remote::pgvector"],
|
||||
"safety": ["inline::llama-guard"],
|
||||
"agents": ["inline::meta-reference"],
|
||||
"telemetry": ["inline::meta-reference"],
|
||||
"eval": ["inline::meta-reference"],
|
||||
"datasetio": ["remote::huggingface", "inline::localfs"],
|
||||
"scoring": ["inline::basic", "inline::llm-as-judge", "inline::braintrust"],
|
||||
"tool_runtime": [
|
||||
"remote::brave-search",
|
||||
"remote::tavily-search",
|
||||
"inline::rag-runtime",
|
||||
"remote::model-context-protocol",
|
||||
],
|
||||
}
|
||||
|
||||
name = "vllm-gpu"
|
||||
inference_provider = Provider(
|
||||
provider_id="vllm",
|
||||
provider_type="inline::vllm",
|
||||
config=VLLMConfig.sample_run_config(),
|
||||
)
|
||||
vector_io_provider = Provider(
|
||||
provider_id="faiss",
|
||||
provider_type="inline::faiss",
|
||||
config=FaissVectorIOConfig.sample_run_config(f"~/.llama/distributions/{name}"),
|
||||
)
|
||||
embedding_provider = Provider(
|
||||
provider_id="sentence-transformers",
|
||||
provider_type="inline::sentence-transformers",
|
||||
config=SentenceTransformersInferenceConfig.sample_run_config(),
|
||||
)
|
||||
|
||||
inference_model = ModelInput(
|
||||
model_id="${env.INFERENCE_MODEL}",
|
||||
provider_id="vllm",
|
||||
)
|
||||
embedding_model = ModelInput(
|
||||
model_id="all-MiniLM-L6-v2",
|
||||
provider_id="sentence-transformers",
|
||||
model_type=ModelType.embedding,
|
||||
metadata={
|
||||
"embedding_dimension": 384,
|
||||
},
|
||||
)
|
||||
default_tool_groups = [
|
||||
ToolGroupInput(
|
||||
toolgroup_id="builtin::websearch",
|
||||
provider_id="tavily-search",
|
||||
),
|
||||
ToolGroupInput(
|
||||
toolgroup_id="builtin::rag",
|
||||
provider_id="rag-runtime",
|
||||
),
|
||||
]
|
||||
|
||||
return DistributionTemplate(
|
||||
name=name,
|
||||
distro_type="self_hosted",
|
||||
description="Use a built-in vLLM engine for running LLM inference",
|
||||
container_image=None,
|
||||
template_path=None,
|
||||
providers=providers,
|
||||
run_configs={
|
||||
"run.yaml": RunConfigSettings(
|
||||
provider_overrides={
|
||||
"inference": [inference_provider, embedding_provider],
|
||||
"vector_io": [vector_io_provider],
|
||||
},
|
||||
default_models=[inference_model, embedding_model],
|
||||
default_tool_groups=default_tool_groups,
|
||||
),
|
||||
},
|
||||
run_config_env_vars={
|
||||
"LLAMA_STACK_PORT": (
|
||||
"8321",
|
||||
"Port for the Llama Stack distribution server",
|
||||
),
|
||||
"INFERENCE_MODEL": (
|
||||
"meta-llama/Llama-3.2-3B-Instruct",
|
||||
"Inference model loaded into the vLLM engine",
|
||||
),
|
||||
"TENSOR_PARALLEL_SIZE": (
|
||||
"1",
|
||||
"Number of tensor parallel replicas (number of GPUs to use).",
|
||||
),
|
||||
"MAX_TOKENS": (
|
||||
"4096",
|
||||
"Maximum number of tokens to generate.",
|
||||
),
|
||||
"ENFORCE_EAGER": (
|
||||
"False",
|
||||
"Whether to use eager mode for inference (otherwise cuda graphs are used).",
|
||||
),
|
||||
"GPU_MEMORY_UTILIZATION": (
|
||||
"0.7",
|
||||
"GPU memory utilization for the vLLM engine.",
|
||||
),
|
||||
},
|
||||
)
|
||||
|
|
@ -29,6 +29,7 @@ dependencies = [
|
|||
"jinja2>=3.1.6",
|
||||
"jsonschema",
|
||||
"llama-stack-client>=0.2.15",
|
||||
"llama-api-client>=0.1.2",
|
||||
"openai>=1.66",
|
||||
"prompt-toolkit",
|
||||
"python-dotenv",
|
||||
|
|
@ -90,6 +91,7 @@ unit = [
|
|||
"pymilvus>=2.5.12",
|
||||
"litellm",
|
||||
"together",
|
||||
"coverage",
|
||||
]
|
||||
# These are the core dependencies required for running integration tests. They are shared across all
|
||||
# providers. If a provider requires additional dependencies, please add them to your environment
|
||||
|
|
@ -241,7 +243,6 @@ exclude = [
|
|||
"^llama_stack/distribution/store/registry\\.py$",
|
||||
"^llama_stack/distribution/utils/exec\\.py$",
|
||||
"^llama_stack/distribution/utils/prompt_for_config\\.py$",
|
||||
"^llama_stack/models/llama/llama3/chat_format\\.py$",
|
||||
"^llama_stack/models/llama/llama3/interface\\.py$",
|
||||
"^llama_stack/models/llama/llama3/tokenizer\\.py$",
|
||||
"^llama_stack/models/llama/llama3/tool_utils\\.py$",
|
||||
|
|
@ -254,10 +255,8 @@ exclude = [
|
|||
"^llama_stack/models/llama/llama3/generation\\.py$",
|
||||
"^llama_stack/models/llama/llama3/multimodal/model\\.py$",
|
||||
"^llama_stack/models/llama/llama4/",
|
||||
"^llama_stack/providers/inline/inference/meta_reference/parallel_utils\\.py$",
|
||||
"^llama_stack/providers/inline/inference/meta_reference/quantization/fp8_impls\\.py$",
|
||||
"^llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers\\.py$",
|
||||
"^llama_stack/providers/inline/inference/vllm/",
|
||||
"^llama_stack/providers/inline/post_training/common/validator\\.py$",
|
||||
"^llama_stack/providers/inline/safety/code_scanner/",
|
||||
"^llama_stack/providers/inline/safety/llama_guard/",
|
||||
|
|
|
|||
|
|
@ -13,6 +13,7 @@ annotated-types==0.7.0
|
|||
anyio==4.8.0
|
||||
# via
|
||||
# httpx
|
||||
# llama-api-client
|
||||
# llama-stack-client
|
||||
# openai
|
||||
# starlette
|
||||
|
|
@ -49,6 +50,7 @@ deprecated==1.2.18
|
|||
# opentelemetry-semantic-conventions
|
||||
distro==1.9.0
|
||||
# via
|
||||
# llama-api-client
|
||||
# llama-stack-client
|
||||
# openai
|
||||
ecdsa==0.19.1
|
||||
|
|
@ -80,6 +82,7 @@ httpcore==1.0.9
|
|||
# via httpx
|
||||
httpx==0.28.1
|
||||
# via
|
||||
# llama-api-client
|
||||
# llama-stack
|
||||
# llama-stack-client
|
||||
# openai
|
||||
|
|
@ -101,6 +104,8 @@ jsonschema==4.23.0
|
|||
# via llama-stack
|
||||
jsonschema-specifications==2024.10.1
|
||||
# via jsonschema
|
||||
llama-api-client==0.1.2
|
||||
# via llama-stack
|
||||
llama-stack-client==0.2.15
|
||||
# via llama-stack
|
||||
markdown-it-py==3.0.0
|
||||
|
|
@ -165,6 +170,7 @@ pycparser==2.22 ; platform_python_implementation != 'PyPy'
|
|||
pydantic==2.10.6
|
||||
# via
|
||||
# fastapi
|
||||
# llama-api-client
|
||||
# llama-stack
|
||||
# llama-stack-client
|
||||
# openai
|
||||
|
|
@ -215,6 +221,7 @@ six==1.17.0
|
|||
sniffio==1.3.1
|
||||
# via
|
||||
# anyio
|
||||
# llama-api-client
|
||||
# llama-stack-client
|
||||
# openai
|
||||
starlette==0.45.3
|
||||
|
|
@ -239,6 +246,7 @@ typing-extensions==4.12.2
|
|||
# anyio
|
||||
# fastapi
|
||||
# huggingface-hub
|
||||
# llama-api-client
|
||||
# llama-stack-client
|
||||
# openai
|
||||
# opentelemetry-sdk
|
||||
|
|
|
|||
|
|
@ -16,4 +16,9 @@ if [ $FOUND_PYTHON -ne 0 ]; then
     uv python install "$PYTHON_VERSION"
 fi

-uv run --python "$PYTHON_VERSION" --with-editable . --group unit pytest -s -v tests/unit/ $@
+# Run unit tests with coverage
+uv run --python "$PYTHON_VERSION" --with-editable . --group unit \
+  coverage run --source=llama_stack -m pytest -s -v tests/unit/ "$@"
+
+# Generate HTML coverage report
+uv run --python "$PYTHON_VERSION" coverage html -d htmlcov-$PYTHON_VERSION
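A rough Python equivalent of the two shell steps above, assuming `coverage` and `pytest` are importable in the environment; code imported before `cov.start()` is not measured, so this is a sketch rather than a drop-in replacement:

```python
import coverage
import pytest

cov = coverage.Coverage(source=["llama_stack"])
cov.start()
exit_code = pytest.main(["-s", "-v", "tests/unit/"])  # run the unit suite under coverage
cov.stop()
cov.save()
cov.html_report(directory="htmlcov-3.12")  # mirrors htmlcov-$PYTHON_VERSION for the default version
print(f"pytest exit code: {exit_code}")
```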
|
|
|
|||
|
|
@ -5,17 +5,20 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
from io import BytesIO
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from openai import OpenAI
|
||||
|
||||
from llama_stack.distribution.datatypes import User
|
||||
from llama_stack.distribution.library_client import LlamaStackAsLibraryClient
|
||||
|
||||
|
||||
def test_openai_client_basic_operations(openai_client, client_with_models):
|
||||
def test_openai_client_basic_operations(compat_client, client_with_models):
|
||||
"""Test basic file operations through OpenAI client."""
|
||||
if isinstance(client_with_models, LlamaStackAsLibraryClient):
|
||||
pytest.skip("OpenAI files are not supported when testing with library client yet.")
|
||||
client = openai_client
|
||||
if isinstance(client_with_models, LlamaStackAsLibraryClient) and isinstance(compat_client, OpenAI):
|
||||
pytest.skip("OpenAI files are not supported when testing with LlamaStackAsLibraryClient")
|
||||
client = compat_client
|
||||
|
||||
test_content = b"files test content"
|
||||
|
||||
|
|
@ -41,7 +44,12 @@ def test_openai_client_basic_operations(openai_client, client_with_models):
|
|||
# Retrieve file content - OpenAI client returns httpx Response object
|
||||
content_response = client.files.content(uploaded_file.id)
|
||||
# The response is an httpx Response object with .content attribute containing bytes
|
||||
content = content_response.content
|
||||
if isinstance(content_response, str):
|
||||
# Llama Stack Client returns a str
|
||||
# TODO: fix Llama Stack Client
|
||||
content = bytes(content_response, "utf-8")
|
||||
else:
|
||||
content = content_response.content
|
||||
assert content == test_content
|
||||
|
||||
# Delete file
|
||||
|
|
@ -55,3 +63,218 @@ def test_openai_client_basic_operations(openai_client, client_with_models):
|
|||
except Exception:
|
||||
pass
|
||||
raise e
|
||||
|
||||
|
||||
@patch("llama_stack.providers.utils.sqlstore.authorized_sqlstore.get_authenticated_user")
|
||||
def test_files_authentication_isolation(mock_get_authenticated_user, compat_client, client_with_models):
|
||||
"""Test that users can only access their own files."""
|
||||
if isinstance(client_with_models, LlamaStackAsLibraryClient) and isinstance(compat_client, OpenAI):
|
||||
pytest.skip("OpenAI files are not supported when testing with LlamaStackAsLibraryClient")
|
||||
if not isinstance(client_with_models, LlamaStackAsLibraryClient):
|
||||
pytest.skip("Authentication tests require LlamaStackAsLibraryClient (library mode)")
|
||||
|
||||
client = compat_client
|
||||
|
||||
# Create two test users
|
||||
user1 = User("user1", {"roles": ["user"], "teams": ["team-a"]})
|
||||
user2 = User("user2", {"roles": ["user"], "teams": ["team-b"]})
|
||||
|
||||
# User 1 uploads a file
|
||||
mock_get_authenticated_user.return_value = user1
|
||||
test_content_1 = b"User 1's private file content"
|
||||
|
||||
with BytesIO(test_content_1) as file_buffer:
|
||||
file_buffer.name = "user1_file.txt"
|
||||
user1_file = client.files.create(file=file_buffer, purpose="assistants")
|
||||
|
||||
# User 2 uploads a file
|
||||
mock_get_authenticated_user.return_value = user2
|
||||
test_content_2 = b"User 2's private file content"
|
||||
|
||||
with BytesIO(test_content_2) as file_buffer:
|
||||
file_buffer.name = "user2_file.txt"
|
||||
user2_file = client.files.create(file=file_buffer, purpose="assistants")
|
||||
|
||||
try:
|
||||
# User 1 can see their own file
|
||||
mock_get_authenticated_user.return_value = user1
|
||||
user1_files = client.files.list()
|
||||
user1_file_ids = [f.id for f in user1_files.data]
|
||||
assert user1_file.id in user1_file_ids
|
||||
assert user2_file.id not in user1_file_ids # Cannot see user2's file
|
||||
|
||||
# User 2 can see their own file
|
||||
mock_get_authenticated_user.return_value = user2
|
||||
user2_files = client.files.list()
|
||||
user2_file_ids = [f.id for f in user2_files.data]
|
||||
assert user2_file.id in user2_file_ids
|
||||
assert user1_file.id not in user2_file_ids # Cannot see user1's file
|
||||
|
||||
# User 1 can retrieve their own file
|
||||
mock_get_authenticated_user.return_value = user1
|
||||
retrieved_file = client.files.retrieve(user1_file.id)
|
||||
assert retrieved_file.id == user1_file.id
|
||||
|
||||
# User 1 cannot retrieve user2's file
|
||||
mock_get_authenticated_user.return_value = user1
|
||||
with pytest.raises(ValueError, match="not found"):
|
||||
client.files.retrieve(user2_file.id)
|
||||
|
||||
# User 1 can access their file content
|
||||
mock_get_authenticated_user.return_value = user1
|
||||
content_response = client.files.content(user1_file.id)
|
||||
if isinstance(content_response, str):
|
||||
content = bytes(content_response, "utf-8")
|
||||
else:
|
||||
content = content_response.content
|
||||
assert content == test_content_1
|
||||
|
||||
# User 1 cannot access user2's file content
|
||||
mock_get_authenticated_user.return_value = user1
|
||||
with pytest.raises(ValueError, match="not found"):
|
||||
client.files.content(user2_file.id)
|
||||
|
||||
# User 1 can delete their own file
|
||||
mock_get_authenticated_user.return_value = user1
|
||||
delete_response = client.files.delete(user1_file.id)
|
||||
assert delete_response.deleted is True
|
||||
|
||||
# User 1 cannot delete user2's file
|
||||
mock_get_authenticated_user.return_value = user1
|
||||
with pytest.raises(ValueError, match="not found"):
|
||||
client.files.delete(user2_file.id)
|
||||
|
||||
# User 2 can still access their file after user1's file is deleted
|
||||
mock_get_authenticated_user.return_value = user2
|
||||
retrieved_file = client.files.retrieve(user2_file.id)
|
||||
assert retrieved_file.id == user2_file.id
|
||||
|
||||
# Cleanup user2's file
|
||||
mock_get_authenticated_user.return_value = user2
|
||||
client.files.delete(user2_file.id)
|
||||
|
||||
except Exception as e:
|
||||
# Cleanup in case of failure
|
||||
try:
|
||||
mock_get_authenticated_user.return_value = user1
|
||||
client.files.delete(user1_file.id)
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
mock_get_authenticated_user.return_value = user2
|
||||
client.files.delete(user2_file.id)
|
||||
except Exception:
|
||||
pass
|
||||
raise e
|
||||
|
||||
|
||||
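The user-switching trick these tests rely on, shown as a self-contained sketch: the authenticated-user lookup is patched so each call can impersonate a different principal. The patch target here is the sketch's own module, not the real `authorized_sqlstore` path used above:

```python
from unittest.mock import patch


def get_authenticated_user():
    raise NotImplementedError("replaced by the test")


def whoami() -> str:
    user = get_authenticated_user()
    return user["principal"] if user else "anonymous"


with patch(f"{__name__}.get_authenticated_user") as mock_user:
    mock_user.return_value = {"principal": "user1"}
    assert whoami() == "user1"      # calls see user1
    mock_user.return_value = None
    assert whoami() == "anonymous"  # and then no authenticated user at all
```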
@patch("llama_stack.providers.utils.sqlstore.authorized_sqlstore.get_authenticated_user")
|
||||
def test_files_authentication_shared_attributes(mock_get_authenticated_user, compat_client, client_with_models):
|
||||
"""Test access control with users having identical attributes."""
|
||||
if isinstance(client_with_models, LlamaStackAsLibraryClient) and isinstance(compat_client, OpenAI):
|
||||
pytest.skip("OpenAI files are not supported when testing with LlamaStackAsLibraryClient")
|
||||
if not isinstance(client_with_models, LlamaStackAsLibraryClient):
|
||||
pytest.skip("Authentication tests require LlamaStackAsLibraryClient (library mode)")
|
||||
|
||||
client = compat_client
|
||||
|
||||
# Create users with identical attributes (required for default policy)
|
||||
user_a = User("user-a", {"roles": ["user"], "teams": ["shared-team"]})
|
||||
user_b = User("user-b", {"roles": ["user"], "teams": ["shared-team"]})
|
||||
|
||||
# User A uploads a file
|
||||
mock_get_authenticated_user.return_value = user_a
|
||||
test_content = b"Shared attributes file content"
|
||||
|
||||
with BytesIO(test_content) as file_buffer:
|
||||
file_buffer.name = "shared_attributes_file.txt"
|
||||
shared_file = client.files.create(file=file_buffer, purpose="assistants")
|
||||
|
||||
try:
|
||||
# User B with identical attributes can access the file
|
||||
mock_get_authenticated_user.return_value = user_b
|
||||
files_list = client.files.list()
|
||||
file_ids = [f.id for f in files_list.data]
|
||||
|
||||
# User B should be able to see the file due to identical attributes
|
||||
assert shared_file.id in file_ids
|
||||
|
||||
# User B can retrieve file info
|
||||
retrieved_file = client.files.retrieve(shared_file.id)
|
||||
assert retrieved_file.id == shared_file.id
|
||||
|
||||
# User B can access file content
|
||||
content_response = client.files.content(shared_file.id)
|
||||
if isinstance(content_response, str):
|
||||
content = bytes(content_response, "utf-8")
|
||||
else:
|
||||
content = content_response.content
|
||||
assert content == test_content
|
||||
|
||||
# Cleanup
|
||||
mock_get_authenticated_user.return_value = user_a
|
||||
client.files.delete(shared_file.id)
|
||||
|
||||
except Exception as e:
|
||||
# Cleanup in case of failure
|
||||
try:
|
||||
mock_get_authenticated_user.return_value = user_a
|
||||
client.files.delete(shared_file.id)
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
mock_get_authenticated_user.return_value = user_b
|
||||
client.files.delete(shared_file.id)
|
||||
except Exception:
|
||||
pass
|
||||
raise e
|
||||
|
||||
|
||||
@patch("llama_stack.providers.utils.sqlstore.authorized_sqlstore.get_authenticated_user")
|
||||
def test_files_authentication_anonymous_access(mock_get_authenticated_user, compat_client, client_with_models):
|
||||
"""Test anonymous user behavior when no authentication is present."""
|
||||
if isinstance(client_with_models, LlamaStackAsLibraryClient) and isinstance(compat_client, OpenAI):
|
||||
pytest.skip("OpenAI files are not supported when testing with LlamaStackAsLibraryClient")
|
||||
if not isinstance(client_with_models, LlamaStackAsLibraryClient):
|
||||
pytest.skip("Authentication tests require LlamaStackAsLibraryClient (library mode)")
|
||||
|
||||
client = compat_client
|
||||
|
||||
# Simulate anonymous user (no authentication)
|
||||
mock_get_authenticated_user.return_value = None
|
||||
|
||||
test_content = b"Anonymous file content"
|
||||
|
||||
with BytesIO(test_content) as file_buffer:
|
||||
file_buffer.name = "anonymous_file.txt"
|
||||
anonymous_file = client.files.create(file=file_buffer, purpose="assistants")
|
||||
|
||||
try:
|
||||
# Anonymous user should be able to access their own uploaded file
|
||||
files_list = client.files.list()
|
||||
file_ids = [f.id for f in files_list.data]
|
||||
assert anonymous_file.id in file_ids
|
||||
|
||||
# Can retrieve file info
|
||||
retrieved_file = client.files.retrieve(anonymous_file.id)
|
||||
assert retrieved_file.id == anonymous_file.id
|
||||
|
||||
# Can access file content
|
||||
content_response = client.files.content(anonymous_file.id)
|
||||
if isinstance(content_response, str):
|
||||
content = bytes(content_response, "utf-8")
|
||||
else:
|
||||
content = content_response.content
|
||||
assert content == test_content
|
||||
|
||||
# Can delete the file
|
||||
delete_response = client.files.delete(anonymous_file.id)
|
||||
assert delete_response.deleted is True
|
||||
|
||||
except Exception as e:
|
||||
# Cleanup in case of failure
|
||||
try:
|
||||
client.files.delete(anonymous_file.id)
|
||||
except Exception:
|
||||
pass
|
||||
raise e
|
||||
|
|
|
|||
|
|
@ -257,6 +257,11 @@ def openai_client(client_with_models):
     return OpenAI(base_url=base_url, api_key="fake")


+@pytest.fixture(params=["openai_client", "llama_stack_client"])
+def compat_client(request):
+    return request.getfixturevalue(request.param)
+
+
 @pytest.fixture(scope="session", autouse=True)
 def cleanup_server_process(request):
     """Cleanup server process at the end of the test session."""
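The `compat_client` fixture above combines fixture parametrization with `getfixturevalue` indirection, so every test that asks for it runs once per underlying client. A minimal standalone illustration of that pattern:

```python
import pytest


@pytest.fixture
def alpha():
    return "client-a"


@pytest.fixture
def beta():
    return "client-b"


# Each test using `any_client` runs once per named fixture in params.
@pytest.fixture(params=["alpha", "beta"])
def any_client(request):
    return request.getfixturevalue(request.param)


def test_runs_twice(any_client):
    assert any_client.startswith("client-")
```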
|
|
|
|||
|
|
@ -123,14 +123,14 @@ class TestPostTraining:
|
|||
logger.info(f"Job artifacts: {artifacts}")
|
||||
|
||||
# TODO: Fix these tests to properly represent the Jobs API in training
|
||||
# @pytest.mark.asyncio
|
||||
#
|
||||
# async def test_get_training_jobs(self, post_training_stack):
|
||||
# post_training_impl = post_training_stack
|
||||
# jobs_list = await post_training_impl.get_training_jobs()
|
||||
# assert isinstance(jobs_list, list)
|
||||
# assert jobs_list[0].job_uuid == "1234"
|
||||
|
||||
# @pytest.mark.asyncio
|
||||
#
|
||||
# async def test_get_training_job_status(self, post_training_stack):
|
||||
# post_training_impl = post_training_stack
|
||||
# job_status = await post_training_impl.get_training_job_status("1234")
|
||||
|
|
@ -139,7 +139,7 @@ class TestPostTraining:
|
|||
# assert job_status.status == JobStatus.completed
|
||||
# assert isinstance(job_status.checkpoints[0], Checkpoint)
|
||||
|
||||
# @pytest.mark.asyncio
|
||||
#
|
||||
# async def test_get_training_job_artifacts(self, post_training_stack):
|
||||
# post_training_impl = post_training_stack
|
||||
# job_artifacts = await post_training_impl.get_training_job_artifacts("1234")
|
||||
|
|
|
|||
|
|
@ -5,41 +5,183 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
import time
|
||||
from datetime import UTC, datetime
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from llama_stack_client import Agent
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="telemetry is not stable")
|
||||
def test_agent_query_spans(llama_stack_client, text_model_id):
|
||||
@pytest.fixture(scope="module", autouse=True)
|
||||
def setup_telemetry_data(llama_stack_client, text_model_id):
|
||||
"""Setup fixture that creates telemetry data before tests run."""
|
||||
agent = Agent(llama_stack_client, model=text_model_id, instructions="You are a helpful assistant")
|
||||
session_id = agent.create_session(f"test-session-{uuid4()}")
|
||||
agent.create_turn(
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Give me a sentence that contains the word: hello",
|
||||
}
|
||||
],
|
||||
session_id=session_id,
|
||||
stream=False,
|
||||
|
||||
session_id = agent.create_session(f"test-setup-session-{uuid4()}")
|
||||
|
||||
messages = [
|
||||
"What is 2 + 2?",
|
||||
"Tell me a short joke",
|
||||
]
|
||||
|
||||
for msg in messages:
|
||||
agent.create_turn(
|
||||
messages=[{"role": "user", "content": msg}],
|
||||
session_id=session_id,
|
||||
stream=False,
|
||||
)
|
||||
|
||||
for i in range(2):
|
||||
llama_stack_client.inference.chat_completion(
|
||||
model_id=text_model_id, messages=[{"role": "user", "content": f"Test trace {i}"}]
|
||||
)
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
while time.time() - start_time < 30:
|
||||
traces = llama_stack_client.telemetry.query_traces(limit=10)
|
||||
if len(traces) >= 4:
|
||||
break
|
||||
time.sleep(1)
|
||||
|
||||
if len(traces) < 4:
|
||||
pytest.fail(f"Failed to create sufficient telemetry data after 30s. Got {len(traces)} traces.")
|
||||
|
||||
yield
|
||||
|
||||
|
||||
def test_query_traces_basic(llama_stack_client):
|
||||
"""Test basic trace querying functionality with proper data validation."""
|
||||
all_traces = llama_stack_client.telemetry.query_traces(limit=5)
|
||||
|
||||
assert isinstance(all_traces, list), "Should return a list of traces"
|
||||
assert len(all_traces) >= 4, "Should have at least 4 traces from setup"
|
||||
|
||||
# Verify trace structure and data quality
|
||||
first_trace = all_traces[0]
|
||||
assert hasattr(first_trace, "trace_id"), "Trace should have trace_id"
|
||||
assert hasattr(first_trace, "start_time"), "Trace should have start_time"
|
||||
assert hasattr(first_trace, "root_span_id"), "Trace should have root_span_id"
|
||||
|
||||
# Validate trace_id is a valid UUID format
|
||||
assert isinstance(first_trace.trace_id, str) and len(first_trace.trace_id) > 0, (
|
||||
"trace_id should be non-empty string"
|
||||
)
|
||||
|
||||
# Wait for the span to be logged
|
||||
time.sleep(2)
|
||||
# Validate start_time format and not in the future
|
||||
now = datetime.now(UTC)
|
||||
if isinstance(first_trace.start_time, str):
|
||||
trace_time = datetime.fromisoformat(first_trace.start_time.replace("Z", "+00:00"))
|
||||
else:
|
||||
# start_time is already a datetime object
|
||||
trace_time = first_trace.start_time
|
||||
if trace_time.tzinfo is None:
|
||||
trace_time = trace_time.replace(tzinfo=UTC)
|
||||
|
||||
agent_logs = []
|
||||
# Ensure trace time is not in the future (but allow any age in the past for persistent test data)
|
||||
time_diff = (now - trace_time).total_seconds()
|
||||
assert time_diff >= 0, f"Trace start_time should not be in the future, got {time_diff}s"
|
||||
|
||||
for span in llama_stack_client.telemetry.query_spans(
|
||||
attribute_filters=[
|
||||
{"key": "session_id", "op": "eq", "value": session_id},
|
||||
],
|
||||
attributes_to_return=["input", "output"],
|
||||
):
|
||||
if span.attributes["output"] != "no shields":
|
||||
agent_logs.append(span.attributes)
|
||||
# Validate root_span_id exists and is non-empty
|
||||
assert isinstance(first_trace.root_span_id, str) and len(first_trace.root_span_id) > 0, (
|
||||
"root_span_id should be non-empty string"
|
||||
)
|
||||
|
||||
assert len(agent_logs) == 1
|
||||
assert "Give me a sentence that contains the word: hello" in agent_logs[0]["input"]
|
||||
assert "hello" in agent_logs[0]["output"].lower()
|
||||
# Test querying specific trace by ID
|
||||
specific_trace = llama_stack_client.telemetry.get_trace(trace_id=first_trace.trace_id)
|
||||
assert specific_trace.trace_id == first_trace.trace_id, "Retrieved trace should match requested ID"
|
||||
assert specific_trace.start_time == first_trace.start_time, "Retrieved trace should have same start_time"
|
||||
assert specific_trace.root_span_id == first_trace.root_span_id, "Retrieved trace should have same root_span_id"
|
||||
|
||||
# Test pagination with proper validation
|
||||
recent_traces = llama_stack_client.telemetry.query_traces(limit=3, offset=0)
|
||||
assert len(recent_traces) <= 3, "Should return at most 3 traces when limit=3"
|
||||
assert len(recent_traces) >= 1, "Should return at least 1 trace"
|
||||
|
||||
# Verify all traces have required fields
|
||||
for trace in recent_traces:
|
||||
assert hasattr(trace, "trace_id") and trace.trace_id, "Each trace should have non-empty trace_id"
|
||||
assert hasattr(trace, "start_time") and trace.start_time, "Each trace should have non-empty start_time"
|
||||
assert hasattr(trace, "root_span_id") and trace.root_span_id, "Each trace should have non-empty root_span_id"
|
||||
|
||||
|
||||
def test_query_spans_basic(llama_stack_client):
|
||||
"""Test basic span querying functionality with proper validation."""
|
||||
spans = llama_stack_client.telemetry.query_spans(attribute_filters=[], attributes_to_return=[])
|
||||
|
||||
assert isinstance(spans, list), "Should return a list of spans"
|
||||
assert len(spans) >= 1, "Should have at least one span from setup"
|
||||
|
||||
# Verify span structure and data quality
|
||||
first_span = spans[0]
|
||||
required_attrs = ["span_id", "name", "trace_id"]
|
||||
for attr in required_attrs:
|
||||
assert hasattr(first_span, attr), f"Span should have {attr} attribute"
|
||||
assert getattr(first_span, attr), f"Span {attr} should not be empty"
|
||||
|
||||
# Validate span data types and content
|
||||
assert isinstance(first_span.span_id, str) and len(first_span.span_id) > 0, "span_id should be non-empty string"
|
||||
assert isinstance(first_span.name, str) and len(first_span.name) > 0, "span name should be non-empty string"
|
||||
assert isinstance(first_span.trace_id, str) and len(first_span.trace_id) > 0, "trace_id should be non-empty string"
|
||||
|
||||
# Verify span belongs to a valid trace (test with traces we know exist)
|
||||
all_traces = llama_stack_client.telemetry.query_traces(limit=10)
|
||||
trace_ids = {t.trace_id for t in all_traces}
|
||||
if first_span.trace_id in trace_ids:
|
||||
trace = llama_stack_client.telemetry.get_trace(trace_id=first_span.trace_id)
|
||||
assert trace is not None, "Should be able to retrieve trace for valid trace_id"
|
||||
assert trace.trace_id == first_span.trace_id, "Trace ID should match span's trace_id"
|
||||
|
||||
# Test with span filtering and validate results
|
||||
filtered_spans = llama_stack_client.telemetry.query_spans(
|
||||
attribute_filters=[{"key": "name", "op": "eq", "value": first_span.name}],
|
||||
attributes_to_return=["name", "span_id"],
|
||||
)
|
||||
assert isinstance(filtered_spans, list), "Should return a list with span name filter"
|
||||
|
||||
# Validate filtered spans if filtering works
|
||||
if len(filtered_spans) > 0:
|
||||
for span in filtered_spans:
|
||||
assert hasattr(span, "name"), "Filtered spans should have name attribute"
|
||||
assert hasattr(span, "span_id"), "Filtered spans should have span_id attribute"
|
||||
assert span.name == first_span.name, "Filtered spans should match the filter criteria"
|
||||
assert isinstance(span.span_id, str) and len(span.span_id) > 0, "Filtered span_id should be valid"
|
||||
|
||||
# Test that all spans have consistent structure
|
||||
for span in spans:
|
||||
for attr in required_attrs:
|
||||
assert hasattr(span, attr) and getattr(span, attr), f"All spans should have non-empty {attr}"
|
||||
|
||||
|
||||
def test_telemetry_pagination(llama_stack_client):
|
||||
"""Test pagination in telemetry queries."""
|
||||
# Get total count of traces
|
||||
all_traces = llama_stack_client.telemetry.query_traces(limit=20)
|
||||
total_count = len(all_traces)
|
||||
assert total_count >= 4, "Should have at least 4 traces from setup"
|
||||
|
||||
# Test trace pagination
|
||||
page1 = llama_stack_client.telemetry.query_traces(limit=2, offset=0)
|
||||
page2 = llama_stack_client.telemetry.query_traces(limit=2, offset=2)
|
||||
|
||||
assert len(page1) == 2, "First page should have exactly 2 traces"
|
||||
assert len(page2) >= 1, "Second page should have at least 1 trace"
|
||||
|
||||
# Verify no overlap between pages
|
||||
page1_ids = {t.trace_id for t in page1}
|
||||
page2_ids = {t.trace_id for t in page2}
|
||||
assert len(page1_ids.intersection(page2_ids)) == 0, "Pages should contain different traces"
|
||||
|
||||
# Test ordering
|
||||
ordered_traces = llama_stack_client.telemetry.query_traces(limit=5, order_by=["start_time"])
|
||||
assert len(ordered_traces) >= 4, "Should have at least 4 traces for ordering test"
|
||||
|
||||
# Verify ordering by start_time
|
||||
for i in range(len(ordered_traces) - 1):
|
||||
current_time = ordered_traces[i].start_time
|
||||
next_time = ordered_traces[i + 1].start_time
|
||||
assert current_time <= next_time, f"Traces should be ordered by start_time: {current_time} > {next_time}"
|
||||
|
||||
# Test limit behavior
|
||||
limited = llama_stack_client.telemetry.query_traces(limit=3)
|
||||
assert len(limited) == 3, "Should return exactly 3 traces when limit=3"
|
||||
|
|
|
|||
|
|
@@ -1,9 +1,17 @@
|
|||
# Llama Stack Unit Tests
|
||||
|
||||
## Unit Tests
|
||||
|
||||
Unit tests verify individual components and functions in isolation. They are fast, reliable, and don't require external services.
|
||||
|
||||
### Prerequisites
|
||||
|
||||
1. **Python Environment**: Ensure you have Python 3.12+ installed
|
||||
2. **uv Package Manager**: Install `uv` if not already installed
|
||||
|
||||
You can run the unit tests with:
|
||||
|
||||
```bash
|
||||
source .venv/bin/activate
|
||||
./scripts/unit-tests.sh [PYTEST_ARGS]
|
||||
```
|
||||
|
||||
|
|
@@ -19,3 +27,21 @@ If you'd like to run for a non-default version of Python (currently 3.12), pass
|
|||
source .venv/bin/activate
|
||||
PYTHON_VERSION=3.13 ./scripts/unit-tests.sh
|
||||
```
|
||||
|
||||
### Test Configuration
|
||||
|
||||
- **Test Discovery**: Tests are automatically discovered in the `tests/unit/` directory
|
||||
- **Async Support**: Tests use `--asyncio-mode=auto` for automatic async test handling (see the sketch after this list)
|
||||
- **Coverage**: Tests generate coverage reports in `htmlcov/` directory
|
||||
- **Python Version**: Defaults to Python 3.12, but can be overridden with `PYTHON_VERSION` environment variable
|
||||
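Because the runner enables `--asyncio-mode=auto`, plain `async def` test functions are collected and awaited without an explicit `@pytest.mark.asyncio` marker. A minimal sketch of such a test (the coroutine and file name here are illustrative, not part of the repo):

```python
# tests/unit/test_async_example.py  (illustrative file name)
import asyncio


async def add_later(x: int, y: int) -> int:
    # Toy coroutine standing in for real async provider code.
    await asyncio.sleep(0)
    return x + y


async def test_add_later():
    # With --asyncio-mode=auto, pytest awaits this coroutine automatically.
    assert await add_later(2, 3) == 5
```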
|
||||
### Coverage Reports
|
||||
|
||||
After running tests, you can view coverage reports:
|
||||
|
||||
```bash
|
||||
# Open HTML coverage report in browser
|
||||
open htmlcov/index.html # macOS
|
||||
xdg-open htmlcov/index.html # Linux
|
||||
start htmlcov/index.html # Windows
|
||||
```
|
||||
|
|
|
|||
|
|
@@ -9,6 +9,7 @@ import pytest
|
|||
|
||||
from llama_stack.apis.common.responses import Order
|
||||
from llama_stack.apis.files import OpenAIFilePurpose
|
||||
from llama_stack.distribution.access_control.access_control import default_policy
|
||||
from llama_stack.providers.inline.files.localfs import (
|
||||
LocalfsFilesImpl,
|
||||
LocalfsFilesImplConfig,
|
||||
|
|
@@ -38,7 +39,7 @@ async def files_provider(tmp_path):
|
|||
storage_dir=storage_dir.as_posix(), metadata_store=SqliteSqlStoreConfig(db_path=db_path.as_posix())
|
||||
)
|
||||
|
||||
provider = LocalfsFilesImpl(config)
|
||||
provider = LocalfsFilesImpl(config, default_policy())
|
||||
await provider.initialize()
|
||||
yield provider
|
||||
|
||||
|
|
|
|||
|
|
@@ -4,14 +4,13 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import asyncio
|
||||
import unittest
|
||||
|
||||
from llama_stack.apis.inference import (
|
||||
ChatCompletionRequest,
|
||||
CompletionMessage,
|
||||
StopReason,
|
||||
SystemMessage,
|
||||
SystemMessageBehavior,
|
||||
ToolCall,
|
||||
ToolConfig,
|
||||
UserMessage,
|
||||
|
|
@@ -25,264 +24,266 @@ from llama_stack.models.llama.datatypes import (
|
|||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||
chat_completion_request_to_messages,
|
||||
chat_completion_request_to_prompt,
|
||||
interleaved_content_as_str,
|
||||
)
|
||||
|
||||
MODEL = "Llama3.1-8B-Instruct"
|
||||
MODEL3_2 = "Llama3.2-3B-Instruct"
|
||||
|
||||
|
||||
class PrepareMessagesTests(unittest.IsolatedAsyncioTestCase):
|
||||
async def asyncSetUp(self):
|
||||
asyncio.get_running_loop().set_debug(False)
|
||||
async def test_system_default():
|
||||
content = "Hello !"
|
||||
request = ChatCompletionRequest(
|
||||
model=MODEL,
|
||||
messages=[
|
||||
UserMessage(content=content),
|
||||
],
|
||||
)
|
||||
messages = chat_completion_request_to_messages(request, MODEL)
|
||||
assert len(messages) == 2
|
||||
assert messages[-1].content == content
|
||||
assert "Cutting Knowledge Date: December 2023" in interleaved_content_as_str(messages[0].content)
|
||||
|
||||
async def test_system_default(self):
|
||||
content = "Hello !"
|
||||
request = ChatCompletionRequest(
|
||||
model=MODEL,
|
||||
messages=[
|
||||
UserMessage(content=content),
|
||||
],
|
||||
)
|
||||
messages = chat_completion_request_to_messages(request, MODEL)
|
||||
self.assertEqual(len(messages), 2)
|
||||
self.assertEqual(messages[-1].content, content)
|
||||
self.assertTrue("Cutting Knowledge Date: December 2023" in messages[0].content)
|
||||
|
||||
async def test_system_builtin_only(self):
|
||||
content = "Hello !"
|
||||
request = ChatCompletionRequest(
|
||||
model=MODEL,
|
||||
messages=[
|
||||
UserMessage(content=content),
|
||||
],
|
||||
tools=[
|
||||
ToolDefinition(tool_name=BuiltinTool.code_interpreter),
|
||||
ToolDefinition(tool_name=BuiltinTool.brave_search),
|
||||
],
|
||||
)
|
||||
messages = chat_completion_request_to_messages(request, MODEL)
|
||||
self.assertEqual(len(messages), 2)
|
||||
self.assertEqual(messages[-1].content, content)
|
||||
self.assertTrue("Cutting Knowledge Date: December 2023" in messages[0].content)
|
||||
self.assertTrue("Tools: brave_search" in messages[0].content)
|
||||
async def test_system_builtin_only():
|
||||
content = "Hello !"
|
||||
request = ChatCompletionRequest(
|
||||
model=MODEL,
|
||||
messages=[
|
||||
UserMessage(content=content),
|
||||
],
|
||||
tools=[
|
||||
ToolDefinition(tool_name=BuiltinTool.code_interpreter),
|
||||
ToolDefinition(tool_name=BuiltinTool.brave_search),
|
||||
],
|
||||
)
|
||||
messages = chat_completion_request_to_messages(request, MODEL)
|
||||
assert len(messages) == 2
|
||||
assert messages[-1].content == content
|
||||
assert "Cutting Knowledge Date: December 2023" in interleaved_content_as_str(messages[0].content)
|
||||
assert "Tools: brave_search" in interleaved_content_as_str(messages[0].content)
|
||||
|
||||
async def test_system_custom_only(self):
|
||||
content = "Hello !"
|
||||
request = ChatCompletionRequest(
|
||||
model=MODEL,
|
||||
messages=[
|
||||
UserMessage(content=content),
|
||||
],
|
||||
tools=[
|
||||
ToolDefinition(
|
||||
tool_name="custom1",
|
||||
description="custom1 tool",
|
||||
parameters={
|
||||
"param1": ToolParamDefinition(
|
||||
param_type="str",
|
||||
description="param1 description",
|
||||
required=True,
|
||||
),
|
||||
},
|
||||
)
|
||||
],
|
||||
tool_config=ToolConfig(tool_prompt_format=ToolPromptFormat.json),
|
||||
)
|
||||
messages = chat_completion_request_to_messages(request, MODEL)
|
||||
self.assertEqual(len(messages), 3)
|
||||
self.assertTrue("Environment: ipython" in messages[0].content)
|
||||
|
||||
self.assertTrue("Return function calls in JSON format" in messages[1].content)
|
||||
self.assertEqual(messages[-1].content, content)
|
||||
async def test_system_custom_only():
|
||||
content = "Hello !"
|
||||
request = ChatCompletionRequest(
|
||||
model=MODEL,
|
||||
messages=[
|
||||
UserMessage(content=content),
|
||||
],
|
||||
tools=[
|
||||
ToolDefinition(
|
||||
tool_name="custom1",
|
||||
description="custom1 tool",
|
||||
parameters={
|
||||
"param1": ToolParamDefinition(
|
||||
param_type="str",
|
||||
description="param1 description",
|
||||
required=True,
|
||||
),
|
||||
},
|
||||
)
|
||||
],
|
||||
tool_config=ToolConfig(tool_prompt_format=ToolPromptFormat.json),
|
||||
)
|
||||
messages = chat_completion_request_to_messages(request, MODEL)
|
||||
assert len(messages) == 3
|
||||
assert "Environment: ipython" in interleaved_content_as_str(messages[0].content)
|
||||
|
||||
async def test_system_custom_and_builtin(self):
|
||||
content = "Hello !"
|
||||
request = ChatCompletionRequest(
|
||||
model=MODEL,
|
||||
messages=[
|
||||
UserMessage(content=content),
|
||||
],
|
||||
tools=[
|
||||
ToolDefinition(tool_name=BuiltinTool.code_interpreter),
|
||||
ToolDefinition(tool_name=BuiltinTool.brave_search),
|
||||
ToolDefinition(
|
||||
tool_name="custom1",
|
||||
description="custom1 tool",
|
||||
parameters={
|
||||
"param1": ToolParamDefinition(
|
||||
param_type="str",
|
||||
description="param1 description",
|
||||
required=True,
|
||||
),
|
||||
},
|
||||
),
|
||||
],
|
||||
)
|
||||
messages = chat_completion_request_to_messages(request, MODEL)
|
||||
self.assertEqual(len(messages), 3)
|
||||
assert "Return function calls in JSON format" in interleaved_content_as_str(messages[1].content)
|
||||
assert messages[-1].content == content
|
||||
|
||||
self.assertTrue("Environment: ipython" in messages[0].content)
|
||||
self.assertTrue("Tools: brave_search" in messages[0].content)
|
||||
|
||||
self.assertTrue("Return function calls in JSON format" in messages[1].content)
|
||||
self.assertEqual(messages[-1].content, content)
|
||||
|
||||
async def test_completion_message_encoding(self):
|
||||
request = ChatCompletionRequest(
|
||||
model=MODEL3_2,
|
||||
messages=[
|
||||
UserMessage(content="hello"),
|
||||
CompletionMessage(
|
||||
content="",
|
||||
stop_reason=StopReason.end_of_turn,
|
||||
tool_calls=[
|
||||
ToolCall(
|
||||
tool_name="custom1",
|
||||
arguments={"param1": "value1"},
|
||||
call_id="123",
|
||||
)
|
||||
],
|
||||
),
|
||||
],
|
||||
tools=[
|
||||
ToolDefinition(
|
||||
tool_name="custom1",
|
||||
description="custom1 tool",
|
||||
parameters={
|
||||
"param1": ToolParamDefinition(
|
||||
param_type="str",
|
||||
description="param1 description",
|
||||
required=True,
|
||||
),
|
||||
},
|
||||
),
|
||||
],
|
||||
tool_config=ToolConfig(tool_prompt_format=ToolPromptFormat.python_list),
|
||||
)
|
||||
prompt = await chat_completion_request_to_prompt(request, request.model)
|
||||
self.assertIn('[custom1(param1="value1")]', prompt)
|
||||
|
||||
request.model = MODEL
|
||||
request.tool_config.tool_prompt_format = ToolPromptFormat.json
|
||||
prompt = await chat_completion_request_to_prompt(request, request.model)
|
||||
self.assertIn(
|
||||
'{"type": "function", "name": "custom1", "parameters": {"param1": "value1"}}',
|
||||
prompt,
|
||||
)
|
||||
|
||||
async def test_user_provided_system_message(self):
|
||||
content = "Hello !"
|
||||
system_prompt = "You are a pirate"
|
||||
request = ChatCompletionRequest(
|
||||
model=MODEL,
|
||||
messages=[
|
||||
SystemMessage(content=system_prompt),
|
||||
UserMessage(content=content),
|
||||
],
|
||||
tools=[
|
||||
ToolDefinition(tool_name=BuiltinTool.code_interpreter),
|
||||
],
|
||||
)
|
||||
messages = chat_completion_request_to_messages(request, MODEL)
|
||||
self.assertEqual(len(messages), 2, messages)
|
||||
self.assertTrue(messages[0].content.endswith(system_prompt))
|
||||
|
||||
self.assertEqual(messages[-1].content, content)
|
||||
|
||||
async def test_repalce_system_message_behavior_builtin_tools(self):
|
||||
content = "Hello !"
|
||||
system_prompt = "You are a pirate"
|
||||
request = ChatCompletionRequest(
|
||||
model=MODEL,
|
||||
messages=[
|
||||
SystemMessage(content=system_prompt),
|
||||
UserMessage(content=content),
|
||||
],
|
||||
tools=[
|
||||
ToolDefinition(tool_name=BuiltinTool.code_interpreter),
|
||||
],
|
||||
tool_config=ToolConfig(
|
||||
tool_choice="auto",
|
||||
tool_prompt_format="python_list",
|
||||
system_message_behavior="replace",
|
||||
async def test_system_custom_and_builtin():
|
||||
content = "Hello !"
|
||||
request = ChatCompletionRequest(
|
||||
model=MODEL,
|
||||
messages=[
|
||||
UserMessage(content=content),
|
||||
],
|
||||
tools=[
|
||||
ToolDefinition(tool_name=BuiltinTool.code_interpreter),
|
||||
ToolDefinition(tool_name=BuiltinTool.brave_search),
|
||||
ToolDefinition(
|
||||
tool_name="custom1",
|
||||
description="custom1 tool",
|
||||
parameters={
|
||||
"param1": ToolParamDefinition(
|
||||
param_type="str",
|
||||
description="param1 description",
|
||||
required=True,
|
||||
),
|
||||
},
|
||||
),
|
||||
)
|
||||
messages = chat_completion_request_to_messages(request, MODEL3_2)
|
||||
self.assertEqual(len(messages), 2, messages)
|
||||
self.assertTrue(messages[0].content.endswith(system_prompt))
|
||||
self.assertIn("Environment: ipython", messages[0].content)
|
||||
self.assertEqual(messages[-1].content, content)
|
||||
],
|
||||
)
|
||||
messages = chat_completion_request_to_messages(request, MODEL)
|
||||
assert len(messages) == 3
|
||||
|
||||
async def test_repalce_system_message_behavior_custom_tools(self):
|
||||
content = "Hello !"
|
||||
system_prompt = "You are a pirate"
|
||||
request = ChatCompletionRequest(
|
||||
model=MODEL,
|
||||
messages=[
|
||||
SystemMessage(content=system_prompt),
|
||||
UserMessage(content=content),
|
||||
],
|
||||
tools=[
|
||||
ToolDefinition(tool_name=BuiltinTool.code_interpreter),
|
||||
ToolDefinition(
|
||||
tool_name="custom1",
|
||||
description="custom1 tool",
|
||||
parameters={
|
||||
"param1": ToolParamDefinition(
|
||||
param_type="str",
|
||||
description="param1 description",
|
||||
required=True,
|
||||
),
|
||||
},
|
||||
),
|
||||
],
|
||||
tool_config=ToolConfig(
|
||||
tool_choice="auto",
|
||||
tool_prompt_format="python_list",
|
||||
system_message_behavior="replace",
|
||||
assert "Environment: ipython" in interleaved_content_as_str(messages[0].content)
|
||||
assert "Tools: brave_search" in interleaved_content_as_str(messages[0].content)
|
||||
|
||||
assert "Return function calls in JSON format" in interleaved_content_as_str(messages[1].content)
|
||||
assert messages[-1].content == content
|
||||
|
||||
|
||||
async def test_completion_message_encoding():
|
||||
request = ChatCompletionRequest(
|
||||
model=MODEL3_2,
|
||||
messages=[
|
||||
UserMessage(content="hello"),
|
||||
CompletionMessage(
|
||||
content="",
|
||||
stop_reason=StopReason.end_of_turn,
|
||||
tool_calls=[
|
||||
ToolCall(
|
||||
tool_name="custom1",
|
||||
arguments={"param1": "value1"},
|
||||
call_id="123",
|
||||
)
|
||||
],
|
||||
),
|
||||
)
|
||||
messages = chat_completion_request_to_messages(request, MODEL3_2)
|
||||
|
||||
self.assertEqual(len(messages), 2, messages)
|
||||
self.assertTrue(messages[0].content.endswith(system_prompt))
|
||||
self.assertIn("Environment: ipython", messages[0].content)
|
||||
self.assertEqual(messages[-1].content, content)
|
||||
|
||||
async def test_replace_system_message_behavior_custom_tools_with_template(self):
|
||||
content = "Hello !"
|
||||
system_prompt = "You are a pirate {{ function_description }}"
|
||||
request = ChatCompletionRequest(
|
||||
model=MODEL,
|
||||
messages=[
|
||||
SystemMessage(content=system_prompt),
|
||||
UserMessage(content=content),
|
||||
],
|
||||
tools=[
|
||||
ToolDefinition(tool_name=BuiltinTool.code_interpreter),
|
||||
ToolDefinition(
|
||||
tool_name="custom1",
|
||||
description="custom1 tool",
|
||||
parameters={
|
||||
"param1": ToolParamDefinition(
|
||||
param_type="str",
|
||||
description="param1 description",
|
||||
required=True,
|
||||
),
|
||||
},
|
||||
),
|
||||
],
|
||||
tool_config=ToolConfig(
|
||||
tool_choice="auto",
|
||||
tool_prompt_format="python_list",
|
||||
system_message_behavior="replace",
|
||||
],
|
||||
tools=[
|
||||
ToolDefinition(
|
||||
tool_name="custom1",
|
||||
description="custom1 tool",
|
||||
parameters={
|
||||
"param1": ToolParamDefinition(
|
||||
param_type="str",
|
||||
description="param1 description",
|
||||
required=True,
|
||||
),
|
||||
},
|
||||
),
|
||||
)
|
||||
messages = chat_completion_request_to_messages(request, MODEL3_2)
|
||||
],
|
||||
tool_config=ToolConfig(tool_prompt_format=ToolPromptFormat.python_list),
|
||||
)
|
||||
prompt = await chat_completion_request_to_prompt(request, request.model)
|
||||
assert '[custom1(param1="value1")]' in prompt
|
||||
|
||||
self.assertEqual(len(messages), 2, messages)
|
||||
self.assertIn("Environment: ipython", messages[0].content)
|
||||
self.assertIn("You are a pirate", messages[0].content)
|
||||
# function description is present in the system prompt
|
||||
self.assertIn('"name": "custom1"', messages[0].content)
|
||||
self.assertEqual(messages[-1].content, content)
|
||||
request.model = MODEL
|
||||
request.tool_config = ToolConfig(tool_prompt_format=ToolPromptFormat.json)
|
||||
prompt = await chat_completion_request_to_prompt(request, request.model)
|
||||
assert '{"type": "function", "name": "custom1", "parameters": {"param1": "value1"}}' in prompt
|
||||
|
||||
|
||||
async def test_user_provided_system_message():
|
||||
content = "Hello !"
|
||||
system_prompt = "You are a pirate"
|
||||
request = ChatCompletionRequest(
|
||||
model=MODEL,
|
||||
messages=[
|
||||
SystemMessage(content=system_prompt),
|
||||
UserMessage(content=content),
|
||||
],
|
||||
tools=[
|
||||
ToolDefinition(tool_name=BuiltinTool.code_interpreter),
|
||||
],
|
||||
)
|
||||
messages = chat_completion_request_to_messages(request, MODEL)
|
||||
assert len(messages) == 2
|
||||
assert interleaved_content_as_str(messages[0].content).endswith(system_prompt)
|
||||
|
||||
assert messages[-1].content == content
|
||||
|
||||
|
||||
async def test_replace_system_message_behavior_builtin_tools():
|
||||
content = "Hello !"
|
||||
system_prompt = "You are a pirate"
|
||||
request = ChatCompletionRequest(
|
||||
model=MODEL,
|
||||
messages=[
|
||||
SystemMessage(content=system_prompt),
|
||||
UserMessage(content=content),
|
||||
],
|
||||
tools=[
|
||||
ToolDefinition(tool_name=BuiltinTool.code_interpreter),
|
||||
],
|
||||
tool_config=ToolConfig(
|
||||
tool_choice="auto",
|
||||
tool_prompt_format=ToolPromptFormat.python_list,
|
||||
system_message_behavior=SystemMessageBehavior.replace,
|
||||
),
|
||||
)
|
||||
messages = chat_completion_request_to_messages(request, MODEL3_2)
|
||||
assert len(messages) == 2
|
||||
assert interleaved_content_as_str(messages[0].content).endswith(system_prompt)
|
||||
assert "Environment: ipython" in interleaved_content_as_str(messages[0].content)
|
||||
assert messages[-1].content == content
|
||||
|
||||
|
||||
async def test_replace_system_message_behavior_custom_tools():
|
||||
content = "Hello !"
|
||||
system_prompt = "You are a pirate"
|
||||
request = ChatCompletionRequest(
|
||||
model=MODEL,
|
||||
messages=[
|
||||
SystemMessage(content=system_prompt),
|
||||
UserMessage(content=content),
|
||||
],
|
||||
tools=[
|
||||
ToolDefinition(tool_name=BuiltinTool.code_interpreter),
|
||||
ToolDefinition(
|
||||
tool_name="custom1",
|
||||
description="custom1 tool",
|
||||
parameters={
|
||||
"param1": ToolParamDefinition(
|
||||
param_type="str",
|
||||
description="param1 description",
|
||||
required=True,
|
||||
),
|
||||
},
|
||||
),
|
||||
],
|
||||
tool_config=ToolConfig(
|
||||
tool_choice="auto",
|
||||
tool_prompt_format=ToolPromptFormat.python_list,
|
||||
system_message_behavior=SystemMessageBehavior.replace,
|
||||
),
|
||||
)
|
||||
messages = chat_completion_request_to_messages(request, MODEL3_2)
|
||||
|
||||
assert len(messages) == 2
|
||||
assert interleaved_content_as_str(messages[0].content).endswith(system_prompt)
|
||||
assert "Environment: ipython" in interleaved_content_as_str(messages[0].content)
|
||||
assert messages[-1].content == content
|
||||
|
||||
|
||||
async def test_replace_system_message_behavior_custom_tools_with_template():
|
||||
content = "Hello !"
|
||||
system_prompt = "You are a pirate {{ function_description }}"
|
||||
request = ChatCompletionRequest(
|
||||
model=MODEL,
|
||||
messages=[
|
||||
SystemMessage(content=system_prompt),
|
||||
UserMessage(content=content),
|
||||
],
|
||||
tools=[
|
||||
ToolDefinition(tool_name=BuiltinTool.code_interpreter),
|
||||
ToolDefinition(
|
||||
tool_name="custom1",
|
||||
description="custom1 tool",
|
||||
parameters={
|
||||
"param1": ToolParamDefinition(
|
||||
param_type="str",
|
||||
description="param1 description",
|
||||
required=True,
|
||||
),
|
||||
},
|
||||
),
|
||||
],
|
||||
tool_config=ToolConfig(
|
||||
tool_choice="auto",
|
||||
tool_prompt_format=ToolPromptFormat.python_list,
|
||||
system_message_behavior=SystemMessageBehavior.replace,
|
||||
),
|
||||
)
|
||||
messages = chat_completion_request_to_messages(request, MODEL3_2)
|
||||
|
||||
assert len(messages) == 2
|
||||
assert "Environment: ipython" in interleaved_content_as_str(messages[0].content)
|
||||
assert "You are a pirate" in interleaved_content_as_str(messages[0].content)
|
||||
# function description is present in the system prompt
|
||||
assert '"name": "custom1"' in interleaved_content_as_str(messages[0].content)
|
||||
assert messages[-1].content == content
|
||||
|
|
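The prompt-adapter hunk above converts `unittest.IsolatedAsyncioTestCase` methods into module-level async pytest functions with plain `assert` statements. A minimal before/after sketch of that migration pattern, using an illustrative test rather than the actual bodies:

```python
import unittest


# Before: unittest style, self.assert* helpers, an explicit test case class
class GreetingTests(unittest.IsolatedAsyncioTestCase):
    async def test_greeting(self):
        messages = ["Hello !"]
        self.assertEqual(len(messages), 1)
        self.assertTrue(messages[-1].endswith("!"))


# After: a bare async function collected by pytest (asyncio auto mode),
# with plain asserts that produce readable failure output
async def test_greeting():
    messages = ["Hello !"]
    assert len(messages) == 1
    assert messages[-1].endswith("!")
```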
|
|||
|
|
@@ -12,7 +12,6 @@
|
|||
# the top-level of this source tree.
|
||||
|
||||
import textwrap
|
||||
import unittest
|
||||
from datetime import datetime
|
||||
|
||||
from llama_stack.models.llama.llama3.prompt_templates import (
|
||||
|
|
@@ -24,59 +23,61 @@ from llama_stack.models.llama.llama3.prompt_templates import (
|
|||
)
|
||||
|
||||
|
||||
class PromptTemplateTests(unittest.TestCase):
|
||||
def check_generator_output(self, generator):
|
||||
for example in generator.data_examples():
|
||||
pt = generator.gen(example)
|
||||
text = pt.render()
|
||||
# print(text) # debugging
|
||||
if not example:
|
||||
continue
|
||||
for tool in example:
|
||||
assert tool.tool_name in text
|
||||
|
||||
def test_system_default(self):
|
||||
generator = SystemDefaultGenerator()
|
||||
today = datetime.now().strftime("%d %B %Y")
|
||||
expected_text = f"Cutting Knowledge Date: December 2023\nToday Date: {today}"
|
||||
assert expected_text.strip("\n") == generator.gen(generator.data_examples()[0]).render()
|
||||
|
||||
def test_system_builtin_only(self):
|
||||
generator = BuiltinToolGenerator()
|
||||
expected_text = textwrap.dedent(
|
||||
"""
|
||||
Environment: ipython
|
||||
Tools: brave_search, wolfram_alpha
|
||||
"""
|
||||
)
|
||||
assert expected_text.strip("\n") == generator.gen(generator.data_examples()[0]).render()
|
||||
|
||||
def test_system_custom_only(self):
|
||||
self.maxDiff = None
|
||||
generator = JsonCustomToolGenerator()
|
||||
self.check_generator_output(generator)
|
||||
|
||||
def test_system_custom_function_tag(self):
|
||||
self.maxDiff = None
|
||||
generator = FunctionTagCustomToolGenerator()
|
||||
self.check_generator_output(generator)
|
||||
|
||||
def test_llama_3_2_system_zero_shot(self):
|
||||
generator = PythonListCustomToolGenerator()
|
||||
self.check_generator_output(generator)
|
||||
|
||||
def test_llama_3_2_provided_system_prompt(self):
|
||||
generator = PythonListCustomToolGenerator()
|
||||
user_system_prompt = textwrap.dedent(
|
||||
"""
|
||||
Overriding message.
|
||||
|
||||
{{ function_description }}
|
||||
"""
|
||||
)
|
||||
example = generator.data_examples()[0]
|
||||
|
||||
pt = generator.gen(example, user_system_prompt)
|
||||
def check_generator_output(generator):
|
||||
for example in generator.data_examples():
|
||||
pt = generator.gen(example)
|
||||
text = pt.render()
|
||||
assert "Overriding message." in text
|
||||
assert '"name": "get_weather"' in text
|
||||
if not example:
|
||||
continue
|
||||
for tool in example:
|
||||
assert tool.tool_name in text
|
||||
|
||||
|
||||
def test_system_default():
|
||||
generator = SystemDefaultGenerator()
|
||||
today = datetime.now().strftime("%d %B %Y")
|
||||
expected_text = f"Cutting Knowledge Date: December 2023\nToday Date: {today}"
|
||||
assert expected_text.strip("\n") == generator.gen(generator.data_examples()[0]).render()
|
||||
|
||||
|
||||
def test_system_builtin_only():
|
||||
generator = BuiltinToolGenerator()
|
||||
expected_text = textwrap.dedent(
|
||||
"""
|
||||
Environment: ipython
|
||||
Tools: brave_search, wolfram_alpha
|
||||
"""
|
||||
)
|
||||
assert expected_text.strip("\n") == generator.gen(generator.data_examples()[0]).render()
|
||||
|
||||
|
||||
def test_system_custom_only():
|
||||
generator = JsonCustomToolGenerator()
|
||||
check_generator_output(generator)
|
||||
|
||||
|
||||
def test_system_custom_function_tag():
|
||||
generator = FunctionTagCustomToolGenerator()
|
||||
check_generator_output(generator)
|
||||
|
||||
|
||||
def test_llama_3_2_system_zero_shot():
|
||||
generator = PythonListCustomToolGenerator()
|
||||
check_generator_output(generator)
|
||||
|
||||
|
||||
def test_llama_3_2_provided_system_prompt():
|
||||
generator = PythonListCustomToolGenerator()
|
||||
user_system_prompt = textwrap.dedent(
|
||||
"""
|
||||
Overriding message.
|
||||
|
||||
{{ function_description }}
|
||||
"""
|
||||
)
|
||||
example = generator.data_examples()[0]
|
||||
|
||||
pt = generator.gen(example, user_system_prompt)
|
||||
text = pt.render()
|
||||
assert "Overriding message." in text
|
||||
assert '"name": "get_weather"' in text
|
||||
|
|
|
|||
|
|
@@ -5,103 +5,110 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
import os
|
||||
import unittest
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from llama_stack.apis.datasets import Dataset, DatasetPurpose, URIDataSource
|
||||
from llama_stack.apis.resource import ResourceType
|
||||
from llama_stack.providers.remote.datasetio.nvidia.config import NvidiaDatasetIOConfig
|
||||
from llama_stack.providers.remote.datasetio.nvidia.datasetio import NvidiaDatasetIOAdapter
|
||||
|
||||
|
||||
class TestNvidiaDatastore(unittest.TestCase):
|
||||
def setUp(self):
|
||||
os.environ["NVIDIA_DATASETS_URL"] = "http://nemo.test/datasets"
|
||||
@pytest.fixture
|
||||
def nvidia_adapter():
|
||||
"""Fixture to set up NvidiaDatasetIOAdapter with mocked requests."""
|
||||
os.environ["NVIDIA_DATASETS_URL"] = "http://nemo.test/datasets"
|
||||
|
||||
config = NvidiaDatasetIOConfig(
|
||||
datasets_url=os.environ["NVIDIA_DATASETS_URL"], dataset_namespace="default", project_id="default"
|
||||
)
|
||||
self.adapter = NvidiaDatasetIOAdapter(config)
|
||||
self.make_request_patcher = patch(
|
||||
"llama_stack.providers.remote.datasetio.nvidia.datasetio.NvidiaDatasetIOAdapter._make_request"
|
||||
)
|
||||
self.mock_make_request = self.make_request_patcher.start()
|
||||
config = NvidiaDatasetIOConfig(
|
||||
datasets_url=os.environ["NVIDIA_DATASETS_URL"], dataset_namespace="default", project_id="default"
|
||||
)
|
||||
adapter = NvidiaDatasetIOAdapter(config)
|
||||
|
||||
def tearDown(self):
|
||||
self.make_request_patcher.stop()
|
||||
with patch(
|
||||
"llama_stack.providers.remote.datasetio.nvidia.datasetio.NvidiaDatasetIOAdapter._make_request"
|
||||
) as mock_make_request:
|
||||
yield adapter, mock_make_request
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def inject_fixtures(self, run_async):
|
||||
self.run_async = run_async
|
||||
|
||||
def _assert_request(self, mock_call, expected_method, expected_path, expected_json=None):
|
||||
"""Helper method to verify request details in mock calls."""
|
||||
call_args = mock_call.call_args
|
||||
def _assert_request(mock_call, expected_method, expected_path, expected_json=None):
|
||||
"""Helper function to verify request details in mock calls."""
|
||||
call_args = mock_call.call_args
|
||||
|
||||
assert call_args[0][0] == expected_method
|
||||
assert call_args[0][1] == expected_path
|
||||
assert call_args[0][0] == expected_method
|
||||
assert call_args[0][1] == expected_path
|
||||
|
||||
if expected_json:
|
||||
for key, value in expected_json.items():
|
||||
assert call_args[1]["json"][key] == value
|
||||
if expected_json:
|
||||
for key, value in expected_json.items():
|
||||
assert call_args[1]["json"][key] == value
|
||||
|
||||
def test_register_dataset(self):
|
||||
self.mock_make_request.return_value = {
|
||||
"id": "dataset-123456",
|
||||
|
||||
def test_register_dataset(nvidia_adapter, run_async):
|
||||
adapter, mock_make_request = nvidia_adapter
|
||||
mock_make_request.return_value = {
|
||||
"id": "dataset-123456",
|
||||
"name": "test-dataset",
|
||||
"namespace": "default",
|
||||
}
|
||||
|
||||
dataset_def = Dataset(
|
||||
identifier="test-dataset",
|
||||
type=ResourceType.dataset,
|
||||
provider_resource_id="",
|
||||
provider_id="",
|
||||
purpose=DatasetPurpose.post_training_messages,
|
||||
source=URIDataSource(uri="https://example.com/data.jsonl"),
|
||||
metadata={"provider_id": "nvidia", "format": "jsonl", "description": "Test dataset description"},
|
||||
)
|
||||
|
||||
run_async(adapter.register_dataset(dataset_def))
|
||||
|
||||
mock_make_request.assert_called_once()
|
||||
_assert_request(
|
||||
mock_make_request,
|
||||
"POST",
|
||||
"/v1/datasets",
|
||||
expected_json={
|
||||
"name": "test-dataset",
|
||||
"namespace": "default",
|
||||
}
|
||||
"files_url": "https://example.com/data.jsonl",
|
||||
"project": "default",
|
||||
"format": "jsonl",
|
||||
"description": "Test dataset description",
|
||||
},
|
||||
)
|
||||
|
||||
dataset_def = Dataset(
|
||||
identifier="test-dataset",
|
||||
type="dataset",
|
||||
provider_resource_id="",
|
||||
provider_id="",
|
||||
purpose=DatasetPurpose.post_training_messages,
|
||||
source=URIDataSource(uri="https://example.com/data.jsonl"),
|
||||
metadata={"provider_id": "nvidia", "format": "jsonl", "description": "Test dataset description"},
|
||||
)
|
||||
|
||||
self.run_async(self.adapter.register_dataset(dataset_def))
|
||||
def test_unregister_dataset(nvidia_adapter, run_async):
|
||||
adapter, mock_make_request = nvidia_adapter
|
||||
mock_make_request.return_value = {
|
||||
"message": "Resource deleted successfully.",
|
||||
"id": "dataset-81RSQp7FKX3rdBtKvF9Skn",
|
||||
"deleted_at": None,
|
||||
}
|
||||
dataset_id = "test-dataset"
|
||||
|
||||
self.mock_make_request.assert_called_once()
|
||||
self._assert_request(
|
||||
self.mock_make_request,
|
||||
"POST",
|
||||
"/v1/datasets",
|
||||
expected_json={
|
||||
"name": "test-dataset",
|
||||
"namespace": "default",
|
||||
"files_url": "https://example.com/data.jsonl",
|
||||
"project": "default",
|
||||
"format": "jsonl",
|
||||
"description": "Test dataset description",
|
||||
},
|
||||
)
|
||||
run_async(adapter.unregister_dataset(dataset_id))
|
||||
|
||||
def test_unregister_dataset(self):
|
||||
self.mock_make_request.return_value = {
|
||||
"message": "Resource deleted successfully.",
|
||||
"id": "dataset-81RSQp7FKX3rdBtKvF9Skn",
|
||||
"deleted_at": None,
|
||||
}
|
||||
dataset_id = "test-dataset"
|
||||
mock_make_request.assert_called_once()
|
||||
_assert_request(mock_make_request, "DELETE", "/v1/datasets/default/test-dataset")
|
||||
|
||||
self.run_async(self.adapter.unregister_dataset(dataset_id))
|
||||
|
||||
self.mock_make_request.assert_called_once()
|
||||
self._assert_request(self.mock_make_request, "DELETE", "/v1/datasets/default/test-dataset")
|
||||
def test_register_dataset_with_custom_namespace_project(run_async):
|
||||
"""Test with custom namespace and project configuration."""
|
||||
os.environ["NVIDIA_DATASETS_URL"] = "http://nemo.test/datasets"
|
||||
|
||||
def test_register_dataset_with_custom_namespace_project(self):
|
||||
custom_config = NvidiaDatasetIOConfig(
|
||||
datasets_url=os.environ["NVIDIA_DATASETS_URL"],
|
||||
dataset_namespace="custom-namespace",
|
||||
project_id="custom-project",
|
||||
)
|
||||
custom_adapter = NvidiaDatasetIOAdapter(custom_config)
|
||||
custom_config = NvidiaDatasetIOConfig(
|
||||
datasets_url=os.environ["NVIDIA_DATASETS_URL"],
|
||||
dataset_namespace="custom-namespace",
|
||||
project_id="custom-project",
|
||||
)
|
||||
custom_adapter = NvidiaDatasetIOAdapter(custom_config)
|
||||
|
||||
self.mock_make_request.return_value = {
|
||||
with patch(
|
||||
"llama_stack.providers.remote.datasetio.nvidia.datasetio.NvidiaDatasetIOAdapter._make_request"
|
||||
) as mock_make_request:
|
||||
mock_make_request.return_value = {
|
||||
"id": "dataset-123456",
|
||||
"name": "test-dataset",
|
||||
"namespace": "custom-namespace",
|
||||
|
|
@@ -109,7 +116,7 @@ class TestNvidiaDatastore(unittest.TestCase):
|
|||
|
||||
dataset_def = Dataset(
|
||||
identifier="test-dataset",
|
||||
type="dataset",
|
||||
type=ResourceType.dataset,
|
||||
provider_resource_id="",
|
||||
provider_id="",
|
||||
purpose=DatasetPurpose.post_training_messages,
|
||||
|
|
@@ -117,11 +124,11 @@ class TestNvidiaDatastore(unittest.TestCase):
|
|||
metadata={"format": "jsonl"},
|
||||
)
|
||||
|
||||
self.run_async(custom_adapter.register_dataset(dataset_def))
|
||||
run_async(custom_adapter.register_dataset(dataset_def))
|
||||
|
||||
self.mock_make_request.assert_called_once()
|
||||
self._assert_request(
|
||||
self.mock_make_request,
|
||||
mock_make_request.assert_called_once()
|
||||
_assert_request(
|
||||
mock_make_request,
|
||||
"POST",
|
||||
"/v1/datasets",
|
||||
expected_json={
|
||||
|
|
@@ -132,7 +139,3 @@ class TestNvidiaDatastore(unittest.TestCase):
|
|||
"format": "jsonl",
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
|
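The datasetio hunk above replaces `setUp`/`tearDown` patch bookkeeping with a generator fixture that yields the adapter together with its mocked `_make_request`. A compact sketch of that fixture pattern, using a hypothetical adapter class so the example stays self-contained:

```python
from unittest.mock import patch

import pytest


class FakeAdapter:
    # Hypothetical stand-in for the real adapter; only the method name matters here.
    def _make_request(self, method: str, path: str, **kwargs):
        raise RuntimeError("should be mocked in tests")


@pytest.fixture
def adapter_with_mock():
    adapter = FakeAdapter()
    # The context manager installs the mock for the duration of the test and
    # restores the original method when the fixture finalizes after `yield`.
    with patch.object(adapter, "_make_request") as mock_request:
        yield adapter, mock_request


def test_uses_mock(adapter_with_mock):
    adapter, mock_request = adapter_with_mock
    adapter._make_request("GET", "/v1/datasets")
    mock_request.assert_called_once_with("GET", "/v1/datasets")
```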
|
|||
|
|
@@ -5,7 +5,6 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
import os
|
||||
import unittest
|
||||
import warnings
|
||||
from unittest.mock import patch
|
||||
|
||||
|
|
@@ -27,14 +26,13 @@ from llama_stack.providers.remote.post_training.nvidia.post_training import (
|
|||
)
|
||||
|
||||
|
||||
class TestNvidiaParameters(unittest.TestCase):
|
||||
def setUp(self):
|
||||
os.environ["NVIDIA_BASE_URL"] = "http://nemo.test"
|
||||
class TestNvidiaParameters:
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup_and_teardown(self):
|
||||
"""Setup and teardown for each test method."""
|
||||
os.environ["NVIDIA_CUSTOMIZER_URL"] = "http://nemo.test"
|
||||
|
||||
config = NvidiaPostTrainingConfig(
|
||||
base_url=os.environ["NVIDIA_BASE_URL"], customizer_url=os.environ["NVIDIA_CUSTOMIZER_URL"], api_key=None
|
||||
)
|
||||
config = NvidiaPostTrainingConfig(customizer_url=os.environ["NVIDIA_CUSTOMIZER_URL"], api_key=None)
|
||||
self.adapter = NvidiaPostTrainingAdapter(config)
|
||||
|
||||
self.make_request_patcher = patch(
|
||||
|
|
@@ -48,7 +46,8 @@ class TestNvidiaParameters(unittest.TestCase):
|
|||
"updated_at": "2025-03-04T13:07:47.543605",
|
||||
}
|
||||
|
||||
def tearDown(self):
|
||||
yield
|
||||
|
||||
self.make_request_patcher.stop()
|
||||
|
||||
def _assert_request_params(self, expected_json):
|
||||
|
|
@@ -166,8 +165,8 @@ class TestNvidiaParameters(unittest.TestCase):
|
|||
|
||||
self.run_async(
|
||||
self.adapter.supervised_fine_tune(
|
||||
job_uuid=required_job_uuid, # Required parameter
|
||||
model=required_model, # Required parameter
|
||||
job_uuid=required_job_uuid,
|
||||
model=required_model,
|
||||
checkpoint_dir="",
|
||||
algorithm_config=algorithm_config,
|
||||
training_config=convert_pydantic_to_json_value(training_config),
|
||||
|
|
@@ -198,7 +197,6 @@ class TestNvidiaParameters(unittest.TestCase):
|
|||
data_config = DataConfig(
|
||||
dataset_id="test-dataset",
|
||||
batch_size=8,
|
||||
# Unsupported parameters
|
||||
shuffle=True,
|
||||
data_format=DatasetFormat.instruct,
|
||||
validation_dataset_id="val-dataset",
|
||||
|
|
@@ -207,20 +205,16 @@ class TestNvidiaParameters(unittest.TestCase):
|
|||
optimizer_config = OptimizerConfig(
|
||||
lr=0.0001,
|
||||
weight_decay=0.01,
|
||||
# Unsupported parameters
|
||||
optimizer_type=OptimizerType.adam,
|
||||
num_warmup_steps=100,
|
||||
)
|
||||
|
||||
efficiency_config = EfficiencyConfig(
|
||||
enable_activation_checkpointing=True # Unsupported parameter
|
||||
)
|
||||
efficiency_config = EfficiencyConfig(enable_activation_checkpointing=True)
|
||||
|
||||
training_config = TrainingConfig(
|
||||
n_epochs=1,
|
||||
data_config=data_config,
|
||||
optimizer_config=optimizer_config,
|
||||
# Unsupported parameters
|
||||
efficiency_config=efficiency_config,
|
||||
max_steps_per_epoch=1000,
|
||||
gradient_accumulation_steps=4,
|
||||
|
|
@@ -228,7 +222,6 @@ class TestNvidiaParameters(unittest.TestCase):
|
|||
dtype="bf16",
|
||||
)
|
||||
|
||||
# Capture warnings
|
||||
with warnings.catch_warnings(record=True) as w:
|
||||
warnings.simplefilter("always")
|
||||
|
||||
|
|
@@ -236,7 +229,7 @@ class TestNvidiaParameters(unittest.TestCase):
|
|||
self.adapter.supervised_fine_tune(
|
||||
job_uuid="test-job",
|
||||
model="meta-llama/Llama-3.1-8B-Instruct",
|
||||
checkpoint_dir="test-dir", # Unsupported parameter
|
||||
checkpoint_dir="test-dir",
|
||||
algorithm_config=LoraFinetuningConfig(
|
||||
type="LoRA",
|
||||
apply_lora_to_mlp=True,
|
||||
|
|
@@ -246,8 +239,8 @@ class TestNvidiaParameters(unittest.TestCase):
|
|||
lora_attn_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
|
||||
),
|
||||
training_config=convert_pydantic_to_json_value(training_config),
|
||||
logger_config={"test": "value"}, # Unsupported parameter
|
||||
hyperparam_search_config={"test": "value"}, # Unsupported parameter
|
||||
logger_config={"test": "value"},
|
||||
hyperparam_search_config={"test": "value"},
|
||||
)
|
||||
)
|
||||
|
||||
|
|
@@ -265,7 +258,6 @@ class TestNvidiaParameters(unittest.TestCase):
|
|||
"gradient_accumulation_steps",
|
||||
"max_validation_steps",
|
||||
"dtype",
|
||||
# required unsupported parameters
|
||||
"rank",
|
||||
"apply_lora_to_output",
|
||||
"lora_attn_modules",
|
||||
|
|
@@ -273,7 +265,3 @@ class TestNvidiaParameters(unittest.TestCase):
|
|||
]
|
||||
for field in fields:
|
||||
assert any(field in text for text in warning_texts)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
|
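The parameter tests above capture warnings with `warnings.catch_warnings(record=True)` and then assert that every unsupported field name shows up in at least one warning message. A minimal sketch of that checking pattern, with an illustrative function that warns about ignored options:

```python
import warnings


def configure(ignored_fields: list[str]) -> None:
    # Toy stand-in: warn about every field the backend does not support.
    for field in ignored_fields:
        warnings.warn(f"Parameter '{field}' is not supported and will be ignored.", UserWarning)


def test_unsupported_fields_are_warned_about():
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        configure(["shuffle", "weight_decay"])

    warning_texts = [str(w.message) for w in caught]
    for field in ("shuffle", "weight_decay"):
        assert any(field in text for text in warning_texts)
```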
|
|||
|
|
@@ -5,13 +5,11 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
import os
|
||||
import unittest
|
||||
import warnings
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from llama_stack.apis.models import Model, ModelType
|
||||
from llama_stack.apis.post_training.post_training import (
|
||||
DataConfig,
|
||||
DatasetFormat,
|
||||
|
|
@@ -22,7 +20,6 @@ from llama_stack.apis.post_training.post_training import (
|
|||
TrainingConfig,
|
||||
)
|
||||
from llama_stack.distribution.library_client import convert_pydantic_to_json_value
|
||||
from llama_stack.providers.remote.inference.nvidia.nvidia import NVIDIAConfig, NVIDIAInferenceAdapter
|
||||
from llama_stack.providers.remote.post_training.nvidia.post_training import (
|
||||
ListNvidiaPostTrainingJobs,
|
||||
NvidiaPostTrainingAdapter,
|
||||
|
|
@@ -32,331 +29,297 @@ from llama_stack.providers.remote.post_training.nvidia.post_training import (
|
|||
)
|
||||
|
||||
|
||||
class TestNvidiaPostTraining(unittest.TestCase):
|
||||
def setUp(self):
|
||||
os.environ["NVIDIA_BASE_URL"] = "http://nemo.test" # needed for llm inference
|
||||
os.environ["NVIDIA_CUSTOMIZER_URL"] = "http://nemo.test" # needed for nemo customizer
|
||||
@pytest.fixture
|
||||
def nvidia_post_training_adapter():
|
||||
"""Fixture to create and configure the NVIDIA post training adapter."""
|
||||
os.environ["NVIDIA_CUSTOMIZER_URL"] = "http://nemo.test" # needed for nemo customizer
|
||||
|
||||
config = NvidiaPostTrainingConfig(
|
||||
base_url=os.environ["NVIDIA_BASE_URL"], customizer_url=os.environ["NVIDIA_CUSTOMIZER_URL"], api_key=None
|
||||
config = NvidiaPostTrainingConfig(customizer_url=os.environ["NVIDIA_CUSTOMIZER_URL"], api_key=None)
|
||||
adapter = NvidiaPostTrainingAdapter(config)
|
||||
|
||||
with patch.object(adapter, "_make_request") as mock_make_request:
|
||||
yield adapter, mock_make_request
|
||||
|
||||
|
||||
def _assert_request(mock_call, expected_method, expected_path, expected_params=None, expected_json=None):
|
||||
"""Helper method to verify request details in mock calls."""
|
||||
call_args = mock_call.call_args
|
||||
|
||||
if expected_method and expected_path:
|
||||
if isinstance(call_args[0], tuple) and len(call_args[0]) == 2:
|
||||
assert call_args[0] == (expected_method, expected_path)
|
||||
else:
|
||||
assert call_args[1]["method"] == expected_method
|
||||
assert call_args[1]["path"] == expected_path
|
||||
|
||||
if expected_params:
|
||||
assert call_args[1]["params"] == expected_params
|
||||
|
||||
if expected_json:
|
||||
for key, value in expected_json.items():
|
||||
assert call_args[1]["json"][key] == value
|
||||
|
||||
|
||||
async def test_supervised_fine_tune(nvidia_post_training_adapter):
|
||||
"""Test the supervised fine-tuning API call."""
|
||||
adapter, mock_make_request = nvidia_post_training_adapter
|
||||
mock_make_request.return_value = {
|
||||
"id": "cust-JGTaMbJMdqjJU8WbQdN9Q2",
|
||||
"created_at": "2024-12-09T04:06:28.542884",
|
||||
"updated_at": "2024-12-09T04:06:28.542884",
|
||||
"config": {
|
||||
"schema_version": "1.0",
|
||||
"id": "af783f5b-d985-4e5b-bbb7-f9eec39cc0b1",
|
||||
"created_at": "2024-12-09T04:06:28.542657",
|
||||
"updated_at": "2024-12-09T04:06:28.569837",
|
||||
"custom_fields": {},
|
||||
"name": "meta-llama/Llama-3.1-8B-Instruct",
|
||||
"base_model": "meta-llama/Llama-3.1-8B-Instruct",
|
||||
"model_path": "llama-3_1-8b-instruct",
|
||||
"training_types": [],
|
||||
"finetuning_types": ["lora"],
|
||||
"precision": "bf16",
|
||||
"num_gpus": 4,
|
||||
"num_nodes": 1,
|
||||
"micro_batch_size": 1,
|
||||
"tensor_parallel_size": 1,
|
||||
"max_seq_length": 4096,
|
||||
},
|
||||
"dataset": {
|
||||
"schema_version": "1.0",
|
||||
"id": "dataset-XU4pvGzr5tvawnbVxeJMTb",
|
||||
"created_at": "2024-12-09T04:06:28.542657",
|
||||
"updated_at": "2024-12-09T04:06:28.542660",
|
||||
"custom_fields": {},
|
||||
"name": "sample-basic-test",
|
||||
"version_id": "main",
|
||||
"version_tags": [],
|
||||
},
|
||||
"hyperparameters": {
|
||||
"finetuning_type": "lora",
|
||||
"training_type": "sft",
|
||||
"batch_size": 16,
|
||||
"epochs": 2,
|
||||
"learning_rate": 0.0001,
|
||||
"lora": {"alpha": 16},
|
||||
},
|
||||
"output_model": "default/job-1234",
|
||||
"status": "created",
|
||||
"project": "default",
|
||||
"custom_fields": {},
|
||||
"ownership": {"created_by": "me", "access_policies": {}},
|
||||
}
|
||||
|
||||
algorithm_config = LoraFinetuningConfig(
|
||||
type="LoRA",
|
||||
apply_lora_to_mlp=True,
|
||||
apply_lora_to_output=True,
|
||||
alpha=16,
|
||||
rank=16,
|
||||
lora_attn_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
|
||||
)
|
||||
|
||||
data_config = DataConfig(
|
||||
dataset_id="sample-basic-test", batch_size=16, shuffle=False, data_format=DatasetFormat.instruct
|
||||
)
|
||||
|
||||
optimizer_config = OptimizerConfig(
|
||||
optimizer_type=OptimizerType.adam,
|
||||
lr=0.0001,
|
||||
weight_decay=0.01,
|
||||
num_warmup_steps=100,
|
||||
)
|
||||
|
||||
training_config = TrainingConfig(
|
||||
n_epochs=2,
|
||||
data_config=data_config,
|
||||
optimizer_config=optimizer_config,
|
||||
)
|
||||
|
||||
with warnings.catch_warnings(record=True):
|
||||
warnings.simplefilter("always")
|
||||
training_job = await adapter.supervised_fine_tune(
|
||||
job_uuid="1234",
|
||||
model="meta/llama-3.2-1b-instruct@v1.0.0+L40",
|
||||
checkpoint_dir="",
|
||||
algorithm_config=algorithm_config,
|
||||
training_config=convert_pydantic_to_json_value(training_config),
|
||||
logger_config={},
|
||||
hyperparam_search_config={},
|
||||
)
|
||||
self.adapter = NvidiaPostTrainingAdapter(config)
|
||||
self.make_request_patcher = patch(
|
||||
"llama_stack.providers.remote.post_training.nvidia.post_training.NvidiaPostTrainingAdapter._make_request"
|
||||
)
|
||||
self.mock_make_request = self.make_request_patcher.start()
|
||||
|
||||
# Mock the inference client
|
||||
inference_config = NVIDIAConfig(base_url=os.environ["NVIDIA_BASE_URL"], api_key=None)
|
||||
self.inference_adapter = NVIDIAInferenceAdapter(inference_config)
|
||||
# check the output is a PostTrainingJob
|
||||
assert isinstance(training_job, NvidiaPostTrainingJob)
|
||||
assert training_job.job_uuid == "cust-JGTaMbJMdqjJU8WbQdN9Q2"
|
||||
|
||||
self.mock_client = unittest.mock.MagicMock()
|
||||
self.mock_client.chat.completions.create = unittest.mock.AsyncMock()
|
||||
self.inference_mock_make_request = self.mock_client.chat.completions.create
|
||||
self.inference_make_request_patcher = patch(
|
||||
"llama_stack.providers.remote.inference.nvidia.nvidia.NVIDIAInferenceAdapter._get_client",
|
||||
return_value=self.mock_client,
|
||||
)
|
||||
self.inference_make_request_patcher.start()
|
||||
|
||||
def tearDown(self):
|
||||
self.make_request_patcher.stop()
|
||||
self.inference_make_request_patcher.stop()
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def inject_fixtures(self, run_async):
|
||||
self.run_async = run_async
|
||||
|
||||
def _assert_request(self, mock_call, expected_method, expected_path, expected_params=None, expected_json=None):
|
||||
"""Helper method to verify request details in mock calls."""
|
||||
call_args = mock_call.call_args
|
||||
|
||||
if expected_method and expected_path:
|
||||
if isinstance(call_args[0], tuple) and len(call_args[0]) == 2:
|
||||
assert call_args[0] == (expected_method, expected_path)
|
||||
else:
|
||||
assert call_args[1]["method"] == expected_method
|
||||
assert call_args[1]["path"] == expected_path
|
||||
|
||||
if expected_params:
|
||||
assert call_args[1]["params"] == expected_params
|
||||
|
||||
if expected_json:
|
||||
for key, value in expected_json.items():
|
||||
assert call_args[1]["json"][key] == value
|
||||
|
||||
def test_supervised_fine_tune(self):
|
||||
"""Test the supervised fine-tuning API call."""
|
||||
self.mock_make_request.return_value = {
|
||||
"id": "cust-JGTaMbJMdqjJU8WbQdN9Q2",
|
||||
"created_at": "2024-12-09T04:06:28.542884",
|
||||
"updated_at": "2024-12-09T04:06:28.542884",
|
||||
"config": {
|
||||
"schema_version": "1.0",
|
||||
"id": "af783f5b-d985-4e5b-bbb7-f9eec39cc0b1",
|
||||
"created_at": "2024-12-09T04:06:28.542657",
|
||||
"updated_at": "2024-12-09T04:06:28.569837",
|
||||
"custom_fields": {},
|
||||
"name": "meta-llama/Llama-3.1-8B-Instruct",
|
||||
"base_model": "meta-llama/Llama-3.1-8B-Instruct",
|
||||
"model_path": "llama-3_1-8b-instruct",
|
||||
"training_types": [],
|
||||
"finetuning_types": ["lora"],
|
||||
"precision": "bf16",
|
||||
"num_gpus": 4,
|
||||
"num_nodes": 1,
|
||||
"micro_batch_size": 1,
|
||||
"tensor_parallel_size": 1,
|
||||
"max_seq_length": 4096,
|
||||
},
|
||||
"dataset": {
|
||||
"schema_version": "1.0",
|
||||
"id": "dataset-XU4pvGzr5tvawnbVxeJMTb",
|
||||
"created_at": "2024-12-09T04:06:28.542657",
|
||||
"updated_at": "2024-12-09T04:06:28.542660",
|
||||
"custom_fields": {},
|
||||
"name": "sample-basic-test",
|
||||
"version_id": "main",
|
||||
"version_tags": [],
|
||||
},
|
||||
mock_make_request.assert_called_once()
|
||||
_assert_request(
|
||||
mock_make_request,
|
||||
"POST",
|
||||
"/v1/customization/jobs",
|
||||
expected_json={
|
||||
"config": "meta/llama-3.2-1b-instruct@v1.0.0+L40",
|
||||
"dataset": {"name": "sample-basic-test", "namespace": "default"},
|
||||
"hyperparameters": {
|
||||
"finetuning_type": "lora",
|
||||
"training_type": "sft",
|
||||
"batch_size": 16,
|
||||
"finetuning_type": "lora",
|
||||
"epochs": 2,
|
||||
"batch_size": 16,
|
||||
"learning_rate": 0.0001,
|
||||
"weight_decay": 0.01,
|
||||
"lora": {"alpha": 16},
|
||||
},
|
||||
"output_model": "default/job-1234",
|
||||
"status": "created",
|
||||
"project": "default",
|
||||
"custom_fields": {},
|
||||
"ownership": {"created_by": "me", "access_policies": {}},
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
async def test_supervised_fine_tune_with_qat(nvidia_post_training_adapter):
|
||||
"""Test that QAT configuration raises NotImplementedError."""
|
||||
adapter, mock_make_request = nvidia_post_training_adapter
|
||||
|
||||
algorithm_config = QATFinetuningConfig(type="QAT", quantizer_name="quantizer_name", group_size=1)
|
||||
data_config = DataConfig(
|
||||
dataset_id="sample-basic-test", batch_size=16, shuffle=False, data_format=DatasetFormat.instruct
|
||||
)
|
||||
optimizer_config = OptimizerConfig(
|
||||
optimizer_type=OptimizerType.adam,
|
||||
lr=0.0001,
|
||||
weight_decay=0.01,
|
||||
num_warmup_steps=100,
|
||||
)
|
||||
training_config = TrainingConfig(
|
||||
n_epochs=2,
|
||||
data_config=data_config,
|
||||
optimizer_config=optimizer_config,
|
||||
)
|
||||
|
||||
# This will raise NotImplementedError since QAT is not supported
|
||||
with pytest.raises(NotImplementedError):
|
||||
await adapter.supervised_fine_tune(
|
||||
job_uuid="1234",
|
||||
model="meta/llama-3.2-1b-instruct@v1.0.0+L40",
|
||||
checkpoint_dir="",
|
||||
algorithm_config=algorithm_config,
|
||||
training_config=convert_pydantic_to_json_value(training_config),
|
||||
logger_config={},
|
||||
hyperparam_search_config={},
|
||||
)
|
||||
|
||||
|
||||
async def test_get_training_job_status(nvidia_post_training_adapter):
|
||||
"""Test getting training job status with different statuses."""
|
||||
adapter, mock_make_request = nvidia_post_training_adapter
|
||||
|
||||
customizer_status_to_job_status = [
|
||||
("running", "in_progress"),
|
||||
("completed", "completed"),
|
||||
("failed", "failed"),
|
||||
("cancelled", "cancelled"),
|
||||
("pending", "scheduled"),
|
||||
("unknown", "scheduled"),
|
||||
]
|
||||
|
||||
for customizer_status, expected_status in customizer_status_to_job_status:
|
||||
mock_make_request.return_value = {
|
||||
"created_at": "2024-12-09T04:06:28.580220",
|
||||
"updated_at": "2024-12-09T04:21:19.852832",
|
||||
"status": customizer_status,
|
||||
"steps_completed": 1210,
|
||||
"epochs_completed": 2,
|
||||
"percentage_done": 100.0,
|
||||
"best_epoch": 2,
|
||||
"train_loss": 1.718016266822815,
|
||||
"val_loss": 1.8661999702453613,
|
||||
}
|
||||
|
||||
algorithm_config = LoraFinetuningConfig(
|
||||
type="LoRA",
|
||||
apply_lora_to_mlp=True,
|
||||
apply_lora_to_output=True,
|
||||
alpha=16,
|
||||
rank=16,
|
||||
lora_attn_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
|
||||
)
|
||||
|
||||
data_config = DataConfig(
|
||||
dataset_id="sample-basic-test", batch_size=16, shuffle=False, data_format=DatasetFormat.instruct
|
||||
)
|
||||
|
||||
optimizer_config = OptimizerConfig(
|
||||
optimizer_type=OptimizerType.adam,
|
||||
lr=0.0001,
|
||||
weight_decay=0.01,
|
||||
num_warmup_steps=100,
|
||||
)
|
||||
|
||||
training_config = TrainingConfig(
|
||||
n_epochs=2,
|
||||
data_config=data_config,
|
||||
optimizer_config=optimizer_config,
|
||||
)
|
||||
|
||||
with warnings.catch_warnings(record=True):
|
||||
warnings.simplefilter("always")
|
||||
training_job = self.run_async(
|
||||
self.adapter.supervised_fine_tune(
|
||||
job_uuid="1234",
|
||||
model="meta/llama-3.2-1b-instruct@v1.0.0+L40",
|
||||
checkpoint_dir="",
|
||||
algorithm_config=algorithm_config,
|
||||
training_config=convert_pydantic_to_json_value(training_config),
|
||||
logger_config={},
|
||||
hyperparam_search_config={},
|
||||
)
|
||||
)
|
||||
|
||||
# check the output is a PostTrainingJob
|
||||
assert isinstance(training_job, NvidiaPostTrainingJob)
|
||||
assert training_job.job_uuid == "cust-JGTaMbJMdqjJU8WbQdN9Q2"
|
||||
|
||||
self.mock_make_request.assert_called_once()
|
||||
self._assert_request(
|
||||
self.mock_make_request,
|
||||
"POST",
|
||||
"/v1/customization/jobs",
|
||||
expected_json={
|
||||
"config": "meta/llama-3.2-1b-instruct@v1.0.0+L40",
|
||||
"dataset": {"name": "sample-basic-test", "namespace": "default"},
|
||||
"hyperparameters": {
|
||||
"training_type": "sft",
|
||||
"finetuning_type": "lora",
|
||||
"epochs": 2,
|
||||
"batch_size": 16,
|
||||
"learning_rate": 0.0001,
|
||||
"weight_decay": 0.01,
|
||||
"lora": {"alpha": 16},
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
def test_supervised_fine_tune_with_qat(self):
|
||||
algorithm_config = QATFinetuningConfig(type="QAT", quantizer_name="quantizer_name", group_size=1)
|
||||
data_config = DataConfig(
|
||||
dataset_id="sample-basic-test", batch_size=16, shuffle=False, data_format=DatasetFormat.instruct
|
||||
)
|
||||
optimizer_config = OptimizerConfig(
|
||||
optimizer_type=OptimizerType.adam,
|
||||
lr=0.0001,
|
||||
weight_decay=0.01,
|
||||
num_warmup_steps=100,
|
||||
)
|
||||
training_config = TrainingConfig(
|
||||
n_epochs=2,
|
||||
data_config=data_config,
|
||||
optimizer_config=optimizer_config,
|
||||
)
|
||||
# This will raise NotImplementedError since QAT is not supported
|
||||
with self.assertRaises(NotImplementedError):
|
||||
self.run_async(
|
||||
self.adapter.supervised_fine_tune(
|
||||
job_uuid="1234",
|
||||
model="meta/llama-3.2-1b-instruct@v1.0.0+L40",
|
||||
checkpoint_dir="",
|
||||
algorithm_config=algorithm_config,
|
||||
training_config=convert_pydantic_to_json_value(training_config),
|
||||
logger_config={},
|
||||
hyperparam_search_config={},
|
||||
)
|
||||
)
|
||||
|
||||
def test_get_training_job_status(self):
|
||||
customizer_status_to_job_status = [
|
||||
("running", "in_progress"),
|
||||
("completed", "completed"),
|
||||
("failed", "failed"),
|
||||
("cancelled", "cancelled"),
|
||||
("pending", "scheduled"),
|
||||
("unknown", "scheduled"),
|
||||
]
|
||||
|
||||
for customizer_status, expected_status in customizer_status_to_job_status:
|
||||
with self.subTest(customizer_status=customizer_status, expected_status=expected_status):
|
||||
self.mock_make_request.return_value = {
|
||||
"created_at": "2024-12-09T04:06:28.580220",
|
||||
"updated_at": "2024-12-09T04:21:19.852832",
|
||||
"status": customizer_status,
|
||||
"steps_completed": 1210,
|
||||
"epochs_completed": 2,
|
||||
"percentage_done": 100.0,
|
||||
"best_epoch": 2,
|
||||
"train_loss": 1.718016266822815,
|
||||
"val_loss": 1.8661999702453613,
|
||||
}
|
||||
|
||||
job_id = "cust-JGTaMbJMdqjJU8WbQdN9Q2"
|
||||
|
||||
status = self.run_async(self.adapter.get_training_job_status(job_uuid=job_id))
|
||||
|
||||
assert isinstance(status, NvidiaPostTrainingJobStatusResponse)
|
||||
assert status.status.value == expected_status
|
||||
assert status.steps_completed == 1210
|
||||
assert status.epochs_completed == 2
|
||||
assert status.percentage_done == 100.0
|
||||
assert status.best_epoch == 2
|
||||
assert status.train_loss == 1.718016266822815
|
||||
assert status.val_loss == 1.8661999702453613
|
||||
|
||||
self._assert_request(
|
||||
self.mock_make_request,
|
||||
"GET",
|
||||
f"/v1/customization/jobs/{job_id}/status",
|
||||
expected_params={"job_id": job_id},
|
||||
)
|
||||
|
||||
    def test_get_training_jobs(self):
        job_id = "cust-JGTaMbJMdqjJU8WbQdN9Q2"
        self.mock_make_request.return_value = {
            "data": [
                {
                    "id": job_id,
                    "created_at": "2024-12-09T04:06:28.542884",
                    "updated_at": "2024-12-09T04:21:19.852832",
                    "config": {
                        "name": "meta-llama/Llama-3.1-8B-Instruct",
                        "base_model": "meta-llama/Llama-3.1-8B-Instruct",
                    },
                    "dataset": {"name": "default/sample-basic-test"},
                    "hyperparameters": {
                        "finetuning_type": "lora",
                        "training_type": "sft",
                        "batch_size": 16,
                        "epochs": 2,
                        "learning_rate": 0.0001,
                        "lora": {"adapter_dim": 16, "adapter_dropout": 0.1},
                    },
                    "output_model": "default/job-1234",
                    "status": "completed",
                    "project": "default",
                }
            ]
        }

        jobs = self.run_async(self.adapter.get_training_jobs())
        status = await adapter.get_training_job_status(job_uuid=job_id)

        assert isinstance(jobs, ListNvidiaPostTrainingJobs)
        assert len(jobs.data) == 1
        job = jobs.data[0]
        assert job.job_uuid == job_id
        assert job.status.value == "completed"
        assert isinstance(status, NvidiaPostTrainingJobStatusResponse)
        assert status.status.value == expected_status
        # Note: The response object inherits extra fields via ConfigDict(extra="allow")
        # So these attributes should be accessible using getattr with defaults
        assert getattr(status, "steps_completed", None) == 1210
        assert getattr(status, "epochs_completed", None) == 2
        assert getattr(status, "percentage_done", None) == 100.0
        assert getattr(status, "best_epoch", None) == 2
        assert getattr(status, "train_loss", None) == 1.718016266822815
        assert getattr(status, "val_loss", None) == 1.8661999702453613

        self.mock_make_request.assert_called_once()
        self._assert_request(
            self.mock_make_request,
        _assert_request(
            mock_make_request,
            "GET",
            "/v1/customization/jobs",
            expected_params={"page": 1, "page_size": 10, "sort": "created_at"},
        )

    def test_cancel_training_job(self):
        self.mock_make_request.return_value = {}  # Empty response for successful cancellation
        job_id = "cust-JGTaMbJMdqjJU8WbQdN9Q2"

        result = self.run_async(self.adapter.cancel_training_job(job_uuid=job_id))

        assert result is None

        self.mock_make_request.assert_called_once()
        self._assert_request(
            self.mock_make_request,
            "POST",
            f"/v1/customization/jobs/{job_id}/cancel",
            f"/v1/customization/jobs/{job_id}/status",
            expected_params={"job_id": job_id},
        )

    def test_inference_register_model(self):
        model_id = "default/job-1234"
        model_type = ModelType.llm
        model = Model(
            identifier=model_id,
            provider_id="nvidia",
            provider_model_id=model_id,
            provider_resource_id=model_id,
            model_type=model_type,
        )
        result = self.run_async(self.inference_adapter.register_model(model))
        assert result == model
        assert len(self.inference_adapter.alias_to_provider_id_map) > 1
        assert self.inference_adapter.get_provider_model_id(model.provider_model_id) == model_id

        with patch.object(self.inference_adapter, "chat_completion") as mock_chat_completion:
            self.run_async(
                self.inference_adapter.chat_completion(
                    model_id=model_id,
                    messages=[{"role": "user", "content": "Hello, model"}],
                )
            )

            mock_chat_completion.assert_called()
            mock_make_request.reset_mock()


if __name__ == "__main__":
    unittest.main()

async def test_get_training_jobs(nvidia_post_training_adapter):
    """Test getting list of training jobs."""
    adapter, mock_make_request = nvidia_post_training_adapter

    job_id = "cust-JGTaMbJMdqjJU8WbQdN9Q2"
    mock_make_request.return_value = {
        "data": [
            {
                "id": job_id,
                "created_at": "2024-12-09T04:06:28.542884",
                "updated_at": "2024-12-09T04:21:19.852832",
                "config": {
                    "name": "meta-llama/Llama-3.1-8B-Instruct",
                    "base_model": "meta-llama/Llama-3.1-8B-Instruct",
                },
                "dataset": {"name": "default/sample-basic-test"},
                "hyperparameters": {
                    "finetuning_type": "lora",
                    "training_type": "sft",
                    "batch_size": 16,
                    "epochs": 2,
                    "learning_rate": 0.0001,
                    "lora": {"adapter_dim": 16, "adapter_dropout": 0.1},
                },
                "output_model": "default/job-1234",
                "status": "completed",
                "project": "default",
            }
        ]
    }

    jobs = await adapter.get_training_jobs()

    assert isinstance(jobs, ListNvidiaPostTrainingJobs)
    assert len(jobs.data) == 1
    job = jobs.data[0]
    assert job.job_uuid == job_id
    assert job.status.value == "completed"

    mock_make_request.assert_called_once()
    _assert_request(
        mock_make_request,
        "GET",
        "/v1/customization/jobs",
        expected_params={"page": 1, "page_size": 10, "sort": "created_at"},
    )


async def test_cancel_training_job(nvidia_post_training_adapter):
    """Test canceling a training job."""
    adapter, mock_make_request = nvidia_post_training_adapter

    mock_make_request.return_value = {}  # Empty response for successful cancellation
    job_id = "cust-JGTaMbJMdqjJU8WbQdN9Q2"

    result = await adapter.cancel_training_job(job_uuid=job_id)

    assert result is None

    mock_make_request.assert_called_once()
    _assert_request(
        mock_make_request,
        "POST",
        f"/v1/customization/jobs/{job_id}/cancel",
        expected_params={"job_id": job_id},
    )

@@ -8,7 +8,6 @@ from unittest.mock import MagicMock, patch
import numpy as np
import pytest
import pytest_asyncio

from llama_stack.apis.vector_io import QueryChunksResponse


@@ -33,7 +32,7 @@ with patch.dict("sys.modules", {"pymilvus": pymilvus_mock}):
MILVUS_PROVIDER = "milvus"


@pytest_asyncio.fixture
@pytest.fixture
async def mock_milvus_client() -> MagicMock:
    """Create a mock Milvus client with common method behaviors."""
    client = MagicMock()

@@ -84,7 +83,7 @@ async def mock_milvus_client() -> MagicMock:
    return client


@pytest_asyncio.fixture
@pytest.fixture
async def milvus_index(mock_milvus_client):
    """Create a MilvusIndex with mocked client."""
    index = MilvusIndex(client=mock_milvus_client, collection_name="test_collection")

@@ -92,7 +91,6 @@ async def milvus_index(mock_milvus_client):
    # No real cleanup needed since we're using mocks


@pytest.mark.asyncio
async def test_add_chunks(milvus_index, sample_chunks, sample_embeddings, mock_milvus_client):
    # Setup: collection doesn't exist initially, then exists after creation
    mock_milvus_client.has_collection.side_effect = [False, True]

@@ -108,7 +106,6 @@ async def test_add_chunks(milvus_index, sample_chunks, sample_embeddings, mock_m
    assert len(insert_call[1]["data"]) == len(sample_chunks)


@pytest.mark.asyncio
async def test_query_chunks_vector(
    milvus_index, sample_chunks, sample_embeddings, embedding_dimension, mock_milvus_client
):

@@ -125,7 +122,6 @@ async def test_query_chunks_vector(
    mock_milvus_client.search.assert_called_once()


@pytest.mark.asyncio
async def test_query_chunks_keyword_search(milvus_index, sample_chunks, sample_embeddings, mock_milvus_client):
    mock_milvus_client.has_collection.return_value = True
    await milvus_index.add_chunks(sample_chunks, sample_embeddings)

@@ -138,7 +134,6 @@ async def test_query_chunks_keyword_search(milvus_index, sample_chunks, sample_e
    assert len(response.chunks) == 2


@pytest.mark.asyncio
async def test_bm25_fallback_to_simple_search(milvus_index, sample_chunks, sample_embeddings, mock_milvus_client):
    """Test that when BM25 search fails, the system falls back to simple text search."""
    mock_milvus_client.has_collection.return_value = True

@@ -181,7 +176,6 @@ async def test_bm25_fallback_to_simple_search(milvus_index, sample_chunks, sampl
    assert all(score == 1.0 for score in response.scores), "Simple text search should use binary scoring"


@pytest.mark.asyncio
async def test_delete_collection(milvus_index, mock_milvus_client):
    # Test collection deletion
    mock_milvus_client.has_collection.return_value = True

@@ -37,7 +37,7 @@ def loop():
async def sqlite_vec_index(embedding_dimension, tmp_path_factory):
    temp_dir = tmp_path_factory.getbasetemp()
    db_path = str(temp_dir / "test_sqlite.db")
    index = await SQLiteVecIndex.create(dimension=embedding_dimension, db_path=db_path, bank_id="test_bank")
    index = await SQLiteVecIndex.create(dimension=embedding_dimension, db_path=db_path, bank_id="test_bank.123")
    yield index
    await index.delete()


@@ -110,7 +110,7 @@ async def test_chunk_id_conflict(sqlite_vec_index, sample_chunks, embedding_dime
    cur = connection.cursor()

    # Retrieve all chunk IDs to check for duplicates
    cur.execute(f"SELECT id FROM {sqlite_vec_index.metadata_table}")
    cur.execute(f"SELECT id FROM [{sqlite_vec_index.metadata_table}]")
    chunk_ids = [row[0] for row in cur.fetchall()]
    cur.close()
    connection.close()

@@ -64,7 +64,6 @@ class TestRagQuery:
        with pytest.raises(ValueError):
            RAGQueryConfig(mode="invalid_mode")

    @pytest.mark.asyncio
    async def test_query_accepts_valid_modes(self):
        RAGQueryConfig()  # Test default (vector)
        RAGQueryConfig(mode="vector")  # Test vector

@@ -5,73 +5,86 @@
# the root directory of this source tree.

import os
import unittest

import pytest

from llama_stack.distribution.stack import replace_env_vars


class TestReplaceEnvVars(unittest.TestCase):
    def setUp(self):
        # Clear any existing environment variables we'll use in tests
        for var in ["TEST_VAR", "EMPTY_VAR", "ZERO_VAR"]:
            if var in os.environ:
                del os.environ[var]


@pytest.fixture
def setup_env_vars():
    # Clear any existing environment variables we'll use in tests
    for var in ["TEST_VAR", "EMPTY_VAR", "ZERO_VAR"]:
        if var in os.environ:
            del os.environ[var]

        # Set up test environment variables
        os.environ["TEST_VAR"] = "test_value"
        os.environ["EMPTY_VAR"] = ""
        os.environ["ZERO_VAR"] = "0"

    # Set up test environment variables
    os.environ["TEST_VAR"] = "test_value"
    os.environ["EMPTY_VAR"] = ""
    os.environ["ZERO_VAR"] = "0"

    def test_simple_replacement(self):
        self.assertEqual(replace_env_vars("${env.TEST_VAR}"), "test_value")

    yield

    def test_default_value_when_not_set(self):
        self.assertEqual(replace_env_vars("${env.NOT_SET:=default}"), "default")

    def test_default_value_when_set(self):
        self.assertEqual(replace_env_vars("${env.TEST_VAR:=default}"), "test_value")

    def test_default_value_when_empty(self):
        self.assertEqual(replace_env_vars("${env.EMPTY_VAR:=default}"), "default")

    def test_none_value_when_empty(self):
        self.assertEqual(replace_env_vars("${env.EMPTY_VAR:=}"), None)

    def test_value_when_set(self):
        self.assertEqual(replace_env_vars("${env.TEST_VAR:=}"), "test_value")

    def test_empty_var_no_default(self):
        self.assertEqual(replace_env_vars("${env.EMPTY_VAR_NO_DEFAULT:+}"), None)

    def test_conditional_value_when_set(self):
        self.assertEqual(replace_env_vars("${env.TEST_VAR:+conditional}"), "conditional")

    def test_conditional_value_when_not_set(self):
        self.assertEqual(replace_env_vars("${env.NOT_SET:+conditional}"), None)

    def test_conditional_value_when_empty(self):
        self.assertEqual(replace_env_vars("${env.EMPTY_VAR:+conditional}"), None)

    def test_conditional_value_with_zero(self):
        self.assertEqual(replace_env_vars("${env.ZERO_VAR:+conditional}"), "conditional")

    def test_mixed_syntax(self):
        self.assertEqual(
            replace_env_vars("${env.TEST_VAR:=default} and ${env.NOT_SET:+conditional}"), "test_value and "
        )
        self.assertEqual(
            replace_env_vars("${env.NOT_SET:=default} and ${env.TEST_VAR:+conditional}"), "default and conditional"
        )

    def test_nested_structures(self):
        data = {
            "key1": "${env.TEST_VAR:=default}",
            "key2": ["${env.NOT_SET:=default}", "${env.TEST_VAR:+conditional}"],
            "key3": {"nested": "${env.NOT_SET:+conditional}"},
        }
        expected = {"key1": "test_value", "key2": ["default", "conditional"], "key3": {"nested": None}}
        self.assertEqual(replace_env_vars(data), expected)

    # Cleanup after test
    for var in ["TEST_VAR", "EMPTY_VAR", "ZERO_VAR"]:
        if var in os.environ:
            del os.environ[var]


if __name__ == "__main__":
    unittest.main()


def test_simple_replacement(setup_env_vars):
    assert replace_env_vars("${env.TEST_VAR}") == "test_value"


def test_default_value_when_not_set(setup_env_vars):
    assert replace_env_vars("${env.NOT_SET:=default}") == "default"


def test_default_value_when_set(setup_env_vars):
    assert replace_env_vars("${env.TEST_VAR:=default}") == "test_value"


def test_default_value_when_empty(setup_env_vars):
    assert replace_env_vars("${env.EMPTY_VAR:=default}") == "default"


def test_none_value_when_empty(setup_env_vars):
    assert replace_env_vars("${env.EMPTY_VAR:=}") is None


def test_value_when_set(setup_env_vars):
    assert replace_env_vars("${env.TEST_VAR:=}") == "test_value"


def test_empty_var_no_default(setup_env_vars):
    assert replace_env_vars("${env.EMPTY_VAR_NO_DEFAULT:+}") is None


def test_conditional_value_when_set(setup_env_vars):
    assert replace_env_vars("${env.TEST_VAR:+conditional}") == "conditional"


def test_conditional_value_when_not_set(setup_env_vars):
    assert replace_env_vars("${env.NOT_SET:+conditional}") is None


def test_conditional_value_when_empty(setup_env_vars):
    assert replace_env_vars("${env.EMPTY_VAR:+conditional}") is None


def test_conditional_value_with_zero(setup_env_vars):
    assert replace_env_vars("${env.ZERO_VAR:+conditional}") == "conditional"


def test_mixed_syntax(setup_env_vars):
    assert replace_env_vars("${env.TEST_VAR:=default} and ${env.NOT_SET:+conditional}") == "test_value and "
    assert replace_env_vars("${env.NOT_SET:=default} and ${env.TEST_VAR:+conditional}") == "default and conditional"


def test_nested_structures(setup_env_vars):
    data = {
        "key1": "${env.TEST_VAR:=default}",
        "key2": ["${env.NOT_SET:=default}", "${env.TEST_VAR:+conditional}"],
        "key3": {"nested": "${env.NOT_SET:+conditional}"},
    }
    expected = {"key1": "test_value", "key2": ["default", "conditional"], "key3": {"nested": None}}
    assert replace_env_vars(data) == expected

21 uv.lock generated

@@ -1268,6 +1268,23 @@ wheels = [
    { url = "https://files.pythonhosted.org/packages/2a/f7/67689245f48b9e79bcd2f3a10a3690cb1918fb99fffd5a623ed2496bca66/litellm-1.74.2-py3-none-any.whl", hash = "sha256:29bb555b45128e4cc696e72921a6ec24e97b14e9b69e86eed6f155124ad629b1", size = 8587065 },
]

[[package]]
name = "llama-api-client"
version = "0.1.2"
source = { registry = "https://pypi.org/simple" }
dependencies = [
    { name = "anyio" },
    { name = "distro" },
    { name = "httpx" },
    { name = "pydantic" },
    { name = "sniffio" },
    { name = "typing-extensions" },
]
sdist = { url = "https://files.pythonhosted.org/packages/d0/78/875de3a16efd0442718ac47cc27319cd80cc5f38e12298e454e08611acc4/llama_api_client-0.1.2.tar.gz", hash = "sha256:709011f2d506009b1b3b3bceea1c84f2a3a7600df1420fb256e680fcd7251387", size = 113695, upload-time = "2025-06-27T19:56:14.057Z" }
wheels = [
    { url = "https://files.pythonhosted.org/packages/99/08/5d7e6e7e6af5353391376288c200acacebb8e6b156d3636eae598a451673/llama_api_client-0.1.2-py3-none-any.whl", hash = "sha256:8ad6e10726f74b2302bfd766c61c41355a9ecf60f57cde2961882d22af998941", size = 84091, upload-time = "2025-06-27T19:56:12.8Z" },
]

[[package]]
name = "llama-stack"
version = "0.2.15"

@@ -1283,6 +1300,7 @@ dependencies = [
    { name = "huggingface-hub" },
    { name = "jinja2" },
    { name = "jsonschema" },
    { name = "llama-api-client" },
    { name = "llama-stack-client" },
    { name = "openai" },
    { name = "opentelemetry-exporter-otlp-proto-http" },

@@ -1372,6 +1390,7 @@ unit = [
    { name = "aiosqlite" },
    { name = "blobfile" },
    { name = "chardet" },
    { name = "coverage" },
    { name = "faiss-cpu" },
    { name = "litellm" },
    { name = "mcp" },

@@ -1398,6 +1417,7 @@ requires-dist = [
    { name = "jsonschema" },
    { name = "llama-stack-client", specifier = ">=0.2.15" },
    { name = "llama-stack-client", marker = "extra == 'ui'", specifier = ">=0.2.15" },
    { name = "llama-api-client", specifier = ">=0.1.2" },
    { name = "openai", specifier = ">=1.66" },
    { name = "opentelemetry-exporter-otlp-proto-http", specifier = ">=1.30.0" },
    { name = "opentelemetry-sdk", specifier = ">=1.30.0" },

@@ -1480,6 +1500,7 @@ unit = [
    { name = "aiosqlite" },
    { name = "blobfile" },
    { name = "chardet" },
    { name = "coverage" },
    { name = "faiss-cpu" },
    { name = "litellm" },
    { name = "mcp" },