Merge branch 'main' into allow-dynamic-models-nvidia

This commit is contained in:
Matthew Farrellee 2025-07-16 12:53:44 -04:00
commit 6173d7a308
71 changed files with 3107 additions and 2381 deletions

View file

@ -7,7 +7,5 @@ runs:
shell: bash
run: |
docker run -d --name ollama -p 11434:11434 docker.io/leseb/ollama-with-models
# TODO: rebuild an ollama image with llama-guard3:1b
echo "Verifying Ollama status..."
timeout 30 bash -c 'while ! curl -s -L http://127.0.0.1:11434; do sleep 1 && echo "."; done'
docker exec ollama ollama pull llama-guard3:1b

View file

@ -5,6 +5,10 @@ inputs:
description: The Python version to use
required: false
default: "3.12"
client-version:
description: The llama-stack-client-python version to test against (latest or published)
required: false
default: "latest"
runs:
using: "composite"
steps:
@ -20,8 +24,17 @@ runs:
run: |
uv sync --all-groups
uv pip install ollama faiss-cpu
# always test against the latest version of the client
# TODO: this is not necessarily a good idea. we need to test against both published and latest
# to find out backwards compatibility issues.
uv pip install git+https://github.com/meta-llama/llama-stack-client-python.git@main
# Install llama-stack-client-python based on the client-version input
if [ "${{ inputs.client-version }}" = "latest" ]; then
echo "Installing latest llama-stack-client-python from main branch"
uv pip install git+https://github.com/meta-llama/llama-stack-client-python.git@main
elif [ "${{ inputs.client-version }}" = "published" ]; then
echo "Installing published llama-stack-client-python from PyPI"
uv pip install llama-stack-client
else
echo "Invalid client-version: ${{ inputs.client-version }}"
exit 1
fi
uv pip install -e .

View file

@ -1,355 +0,0 @@
name: "Run Llama-stack Tests"
on:
#### Temporarily disable PR runs until tests run as intended within mainline.
#TODO Add this back.
#pull_request_target:
# types: ["opened"]
# branches:
# - 'main'
# paths:
# - 'llama_stack/**/*.py'
# - 'tests/**/*.py'
workflow_dispatch:
inputs:
runner:
description: 'GHA Runner Scale Set label to run workflow on.'
required: true
default: "llama-stack-gha-runner-gpu"
checkout_reference:
description: "The branch, tag, or SHA to checkout"
required: true
default: "main"
debug:
description: 'Run debugging steps?'
required: false
default: "true"
sleep_time:
description: '[DEBUG] sleep time for debugging'
required: true
default: "0"
provider_id:
description: 'ID of your provider'
required: true
default: "meta_reference"
model_id:
description: 'Shorthand name for target model ID (llama_3b or llama_8b)'
required: true
default: "llama_3b"
model_override_3b:
description: 'Specify shorthand model for <llama_3b> '
required: false
default: "Llama3.2-3B-Instruct"
model_override_8b:
description: 'Specify shorthand model for <llama_8b> '
required: false
default: "Llama3.1-8B-Instruct"
env:
# ID used for each test's provider config
PROVIDER_ID: "${{ inputs.provider_id || 'meta_reference' }}"
# Path to model checkpoints within EFS volume
MODEL_CHECKPOINT_DIR: "/data/llama"
# Path to directory to run tests from
TESTS_PATH: "${{ github.workspace }}/llama_stack/providers/tests"
# Keep track of a list of model IDs that are valid to use within pytest fixture marks
AVAILABLE_MODEL_IDs: "llama_3b llama_8b"
# Shorthand name for model ID, used in pytest fixture marks
MODEL_ID: "${{ inputs.model_id || 'llama_3b' }}"
# Override the `llama_3b` / `llama_8b' models, else use the default.
LLAMA_3B_OVERRIDE: "${{ inputs.model_override_3b || 'Llama3.2-3B-Instruct' }}"
LLAMA_8B_OVERRIDE: "${{ inputs.model_override_8b || 'Llama3.1-8B-Instruct' }}"
# Defines which directories in TESTS_PATH to exclude from the test loop
EXCLUDED_DIRS: "__pycache__"
# Defines the output xml reports generated after a test is run
REPORTS_GEN: ""
jobs:
execute_workflow:
name: Execute workload on Self-Hosted GPU k8s runner
permissions:
pull-requests: write
defaults:
run:
shell: bash
runs-on: ${{ inputs.runner != '' && inputs.runner || 'llama-stack-gha-runner-gpu' }}
if: always()
steps:
##############################
#### INITIAL DEBUG CHECKS ####
##############################
- name: "[DEBUG] Check content of the EFS mount"
id: debug_efs_volume
continue-on-error: true
if: inputs.debug == 'true'
run: |
echo "========= Content of the EFS mount ============="
ls -la ${{ env.MODEL_CHECKPOINT_DIR }}
- name: "[DEBUG] Get runner container OS information"
id: debug_os_info
if: ${{ inputs.debug == 'true' }}
run: |
cat /etc/os-release
- name: "[DEBUG] Print environment variables"
id: debug_env_vars
if: ${{ inputs.debug == 'true' }}
run: |
echo "PROVIDER_ID = ${PROVIDER_ID}"
echo "MODEL_CHECKPOINT_DIR = ${MODEL_CHECKPOINT_DIR}"
echo "AVAILABLE_MODEL_IDs = ${AVAILABLE_MODEL_IDs}"
echo "MODEL_ID = ${MODEL_ID}"
echo "LLAMA_3B_OVERRIDE = ${LLAMA_3B_OVERRIDE}"
echo "LLAMA_8B_OVERRIDE = ${LLAMA_8B_OVERRIDE}"
echo "EXCLUDED_DIRS = ${EXCLUDED_DIRS}"
echo "REPORTS_GEN = ${REPORTS_GEN}"
############################
#### MODEL INPUT CHECKS ####
############################
- name: "Check if env.model_id is valid"
id: check_model_id
run: |
if [[ " ${AVAILABLE_MODEL_IDs[@]} " =~ " ${MODEL_ID} " ]]; then
echo "Model ID '${MODEL_ID}' is valid."
else
echo "Model ID '${MODEL_ID}' is invalid. Terminating workflow."
exit 1
fi
#######################
#### CODE CHECKOUT ####
#######################
- name: "Checkout 'meta-llama/llama-stack' repository"
id: checkout_repo
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
with:
ref: ${{ inputs.branch }}
- name: "[DEBUG] Content of the repository after checkout"
id: debug_content_after_checkout
if: ${{ inputs.debug == 'true' }}
run: |
ls -la ${GITHUB_WORKSPACE}
##########################################################
#### OPTIONAL SLEEP DEBUG ####
# #
# Use to "exec" into the test k8s POD and run tests #
# manually to identify what dependencies are being used. #
# #
##########################################################
- name: "[DEBUG] sleep"
id: debug_sleep
if: ${{ inputs.debug == 'true' && inputs.sleep_time != '' }}
run: |
sleep ${{ inputs.sleep_time }}
############################
#### UPDATE SYSTEM PATH ####
############################
- name: "Update path: execute"
id: path_update_exec
run: |
# .local/bin is needed for certain libraries installed below to be recognized
# when calling their executable to install sub-dependencies
mkdir -p ${HOME}/.local/bin
echo "${HOME}/.local/bin" >> "$GITHUB_PATH"
#####################################
#### UPDATE CHECKPOINT DIRECTORY ####
#####################################
- name: "Update checkpoint directory"
id: checkpoint_update
run: |
echo "Checkpoint directory: ${MODEL_CHECKPOINT_DIR}/$LLAMA_3B_OVERRIDE"
if [ "${MODEL_ID}" = "llama_3b" ] && [ -d "${MODEL_CHECKPOINT_DIR}/${LLAMA_3B_OVERRIDE}" ]; then
echo "MODEL_CHECKPOINT_DIR=${MODEL_CHECKPOINT_DIR}/${LLAMA_3B_OVERRIDE}" >> "$GITHUB_ENV"
elif [ "${MODEL_ID}" = "llama_8b" ] && [ -d "${MODEL_CHECKPOINT_DIR}/${LLAMA_8B_OVERRIDE}" ]; then
echo "MODEL_CHECKPOINT_DIR=${MODEL_CHECKPOINT_DIR}/${LLAMA_8B_OVERRIDE}" >> "$GITHUB_ENV"
else
echo "MODEL_ID & LLAMA_*B_OVERRIDE are not a valid pairing. Terminating workflow."
exit 1
fi
- name: "[DEBUG] Checkpoint update check"
id: debug_checkpoint_update
if: ${{ inputs.debug == 'true' }}
run: |
echo "MODEL_CHECKPOINT_DIR (after update) = ${MODEL_CHECKPOINT_DIR}"
##################################
#### DEPENDENCY INSTALLATIONS ####
##################################
- name: "Installing 'apt' required packages"
id: install_apt
run: |
echo "[STEP] Installing 'apt' required packages"
sudo apt update -y
sudo apt install -y python3 python3-pip npm wget
- name: "Installing packages with 'curl'"
id: install_curl
run: |
curl -fsSL https://ollama.com/install.sh | sh
- name: "Installing packages with 'wget'"
id: install_wget
run: |
wget https://repo.continuum.io/miniconda/Miniconda3-latest-Linux-x86_64.sh
chmod +x Miniconda3-latest-Linux-x86_64.sh
./Miniconda3-latest-Linux-x86_64.sh -b install -c pytorch -c nvidia faiss-gpu=1.9.0
# Add miniconda3 bin to system path
echo "${HOME}/miniconda3/bin" >> "$GITHUB_PATH"
- name: "Installing packages with 'npm'"
id: install_npm_generic
run: |
sudo npm install -g junit-merge
- name: "Installing pip dependencies"
id: install_pip_generic
run: |
echo "[STEP] Installing 'llama-stack' models"
pip install -U pip setuptools
pip install -r requirements.txt
pip install -e .
pip install -U \
torch torchvision \
pytest pytest_asyncio \
fairscale lm-format-enforcer \
zmq chardet pypdf \
pandas sentence_transformers together \
aiosqlite
- name: "Installing packages with conda"
id: install_conda_generic
run: |
conda install -q -c pytorch -c nvidia faiss-gpu=1.9.0
#############################################################
#### TESTING TO BE DONE FOR BOTH PRS AND MANUAL DISPATCH ####
#############################################################
- name: "Run Tests: Loop"
id: run_tests_loop
working-directory: "${{ github.workspace }}"
run: |
pattern=""
for dir in llama_stack/providers/tests/*; do
if [ -d "$dir" ]; then
dir_name=$(basename "$dir")
if [[ ! " $EXCLUDED_DIRS " =~ " $dir_name " ]]; then
for file in "$dir"/test_*.py; do
test_name=$(basename "$file")
new_file="result-${dir_name}-${test_name}.xml"
if torchrun $(which pytest) -s -v ${TESTS_PATH}/${dir_name}/${test_name} -m "${PROVIDER_ID} and ${MODEL_ID}" \
--junitxml="${{ github.workspace }}/${new_file}"; then
echo "Ran test: ${test_name}"
else
echo "Did NOT run test: ${test_name}"
fi
pattern+="${new_file} "
done
fi
fi
done
echo "REPORTS_GEN=$pattern" >> "$GITHUB_ENV"
- name: "Test Summary: Merge"
id: test_summary_merge
working-directory: "${{ github.workspace }}"
run: |
echo "Merging the following test result files: ${REPORTS_GEN}"
# Defaults to merging them into 'merged-test-results.xml'
junit-merge ${{ env.REPORTS_GEN }}
############################################
#### AUTOMATIC TESTING ON PULL REQUESTS ####
############################################
#### Run tests ####
- name: "PR - Run Tests"
id: pr_run_tests
working-directory: "${{ github.workspace }}"
if: github.event_name == 'pull_request_target'
run: |
echo "[STEP] Running PyTest tests at 'GITHUB_WORKSPACE' path: ${GITHUB_WORKSPACE} | path: ${{ github.workspace }}"
# (Optional) Add more tests here.
# Merge test results with 'merged-test-results.xml' from above.
# junit-merge <new-test-results> merged-test-results.xml
#### Create test summary ####
- name: "PR - Test Summary"
id: pr_test_summary_create
if: github.event_name == 'pull_request_target'
uses: test-summary/action@31493c76ec9e7aa675f1585d3ed6f1da69269a86 # v2.4
with:
paths: "${{ github.workspace }}/merged-test-results.xml"
output: test-summary.md
- name: "PR - Upload Test Summary"
id: pr_test_summary_upload
if: github.event_name == 'pull_request_target'
uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # v4.6.2
with:
name: test-summary
path: test-summary.md
#### Update PR request ####
- name: "PR - Update comment"
id: pr_update_comment
if: github.event_name == 'pull_request_target'
uses: thollander/actions-comment-pull-request@24bffb9b452ba05a4f3f77933840a6a841d1b32b # v3.0.1
with:
filePath: test-summary.md
########################
#### MANUAL TESTING ####
########################
#### Run tests ####
- name: "Manual - Run Tests: Prep"
id: manual_run_tests
working-directory: "${{ github.workspace }}"
if: github.event_name == 'workflow_dispatch'
run: |
echo "[STEP] Running PyTest tests at 'GITHUB_WORKSPACE' path: ${{ github.workspace }}"
#TODO Use this when collection errors are resolved
# pytest -s -v -m "${PROVIDER_ID} and ${MODEL_ID}" --junitxml="${{ github.workspace }}/merged-test-results.xml"
# (Optional) Add more tests here.
# Merge test results with 'merged-test-results.xml' from above.
# junit-merge <new-test-results> merged-test-results.xml
#### Create test summary ####
- name: "Manual - Test Summary"
id: manual_test_summary
if: always() && github.event_name == 'workflow_dispatch'
uses: test-summary/action@31493c76ec9e7aa675f1585d3ed6f1da69269a86 # v2.4
with:
paths: "${{ github.workspace }}/merged-test-results.xml"

View file

@ -12,6 +12,15 @@ on:
- 'pyproject.toml'
- 'requirements.txt'
- '.github/workflows/integration-tests.yml' # This workflow
- '.github/actions/setup-ollama/action.yml'
schedule:
- cron: '0 0 * * *' # Daily at 12 AM UTC
workflow_dispatch:
inputs:
test-all-client-versions:
description: 'Test against both the latest and published versions'
type: boolean
default: false
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}
@ -45,6 +54,7 @@ jobs:
test-type: ${{ fromJson(needs.discover-tests.outputs.test-type) }}
client-type: [library, server]
python-version: ["3.12", "3.13"]
client-version: ${{ (github.event_name == 'schedule' || github.event.inputs.test-all-client-versions == 'true') && fromJSON('["published", "latest"]') || fromJSON('["latest"]') }}
steps:
- name: Checkout repository
@ -54,6 +64,7 @@ jobs:
uses: ./.github/actions/setup-runner
with:
python-version: ${{ matrix.python-version }}
client-version: ${{ matrix.client-version }}
- name: Setup ollama
uses: ./.github/actions/setup-ollama
@ -108,7 +119,7 @@ jobs:
if: ${{ always() }}
uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # v4.6.2
with:
name: logs-${{ github.run_id }}-${{ github.run_attempt }}-${{ matrix.client-type }}-${{ matrix.test-type }}-${{ matrix.python-version }}
name: logs-${{ github.run_id }}-${{ github.run_attempt }}-${{ matrix.client-type }}-${{ matrix.test-type }}-${{ matrix.python-version }}-${{ matrix.client-version }}
path: |
*.log
retention-days: 1

View file

@ -1,69 +0,0 @@
name: auto-tests
on:
# pull_request:
workflow_dispatch:
inputs:
commit_sha:
description: 'Specific Commit SHA to trigger on'
required: false
default: $GITHUB_SHA # default to the last commit of $GITHUB_REF branch
jobs:
test-llama-stack-as-library:
runs-on: ubuntu-latest
env:
TOGETHER_API_KEY: ${{ secrets.TOGETHER_API_KEY }}
FIREWORKS_API_KEY: ${{ secrets.FIREWORKS_API_KEY }}
TAVILY_SEARCH_API_KEY: ${{ secrets.TAVILY_SEARCH_API_KEY }}
strategy:
matrix:
provider: [fireworks, together]
steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
with:
ref: ${{ github.event.inputs.commit_sha }}
- name: Echo commit SHA
run: |
echo "Triggered on commit SHA: ${{ github.event.inputs.commit_sha }}"
git rev-parse HEAD
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install -r requirements.txt pytest
pip install -e .
- name: Build providers
run: |
llama stack build --template ${{ matrix.provider }} --image-type venv
- name: Install the latest llama-stack-client & llama-models packages
run: |
pip install -e git+https://github.com/meta-llama/llama-stack-client-python.git#egg=llama-stack-client
pip install -e git+https://github.com/meta-llama/llama-models.git#egg=llama-models
- name: Run client-sdk test
working-directory: "${{ github.workspace }}"
env:
REPORT_OUTPUT: md_report.md
shell: bash
run: |
pip install --upgrade pytest-md-report
echo "REPORT_FILE=${REPORT_OUTPUT}" >> "$GITHUB_ENV"
export INFERENCE_MODEL=meta-llama/Llama-3.1-8B-Instruct
LLAMA_STACK_CONFIG=./llama_stack/templates/${{ matrix.provider }}/run.yaml pytest --md-report --md-report-verbose=1 ./tests/client-sdk/inference/ --md-report-output "$REPORT_OUTPUT"
- name: Output reports to the job summary
if: always()
shell: bash
run: |
if [ -f "$REPORT_FILE" ]; then
echo "<details><summary> Test Report for ${{ matrix.provider }} </summary>" >> $GITHUB_STEP_SUMMARY
echo "" >> $GITHUB_STEP_SUMMARY
cat "$REPORT_FILE" >> $GITHUB_STEP_SUMMARY
echo "" >> $GITHUB_STEP_SUMMARY
echo "</details>" >> $GITHUB_STEP_SUMMARY
fi

View file

@ -112,7 +112,7 @@ uv run pre-commit run --all-files
## Running tests
You can find the Llama Stack testing documentation here [here](tests/README.md).
You can find the Llama Stack testing documentation [here](https://github.com/meta-llama/llama-stack/blob/main/tests/README.md).
## Adding a new dependency to the project

View file

@ -11340,6 +11340,9 @@
},
"embedding_dimension": {
"type": "integer"
},
"vector_db_name": {
"type": "string"
}
},
"additionalProperties": false,
@ -13590,10 +13593,6 @@
"provider_id": {
"type": "string",
"description": "The ID of the provider to use for this vector store."
},
"provider_vector_db_id": {
"type": "string",
"description": "The provider-specific vector database ID."
}
},
"additionalProperties": false,
@ -15634,6 +15633,10 @@
"type": "string",
"description": "The identifier of the provider."
},
"vector_db_name": {
"type": "string",
"description": "The name of the vector database."
},
"provider_vector_db_id": {
"type": "string",
"description": "The identifier of the vector database in the provider."

View file

@ -7984,6 +7984,8 @@ components:
type: string
embedding_dimension:
type: integer
vector_db_name:
type: string
additionalProperties: false
required:
- identifier
@ -9494,10 +9496,6 @@ components:
type: string
description: >-
The ID of the provider to use for this vector store.
provider_vector_db_id:
type: string
description: >-
The provider-specific vector database ID.
additionalProperties: false
required:
- name
@ -10945,6 +10943,9 @@ components:
provider_id:
type: string
description: The identifier of the provider.
vector_db_name:
type: string
description: The name of the vector database.
provider_vector_db_id:
type: string
description: >-

View file

@ -0,0 +1,6 @@
# Eval Providers
This section contains documentation for all available providers for the **eval** API.
- [inline::meta-reference](inline_meta-reference.md)
- [remote::nvidia](remote_nvidia.md)

View file

@ -0,0 +1,21 @@
# inline::meta-reference
## Description
Meta's reference implementation of evaluation tasks with support for multiple languages and evaluation metrics.
## Configuration
| Field | Type | Required | Default | Description |
|-------|------|----------|---------|-------------|
| `kvstore` | `utils.kvstore.config.RedisKVStoreConfig \| utils.kvstore.config.SqliteKVStoreConfig \| utils.kvstore.config.PostgresKVStoreConfig \| utils.kvstore.config.MongoDBKVStoreConfig` | No | sqlite | |
## Sample Configuration
```yaml
kvstore:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/dummy}/meta_reference_eval.db
```

View file

@ -0,0 +1,19 @@
# remote::nvidia
## Description
NVIDIA's evaluation provider for running evaluation tasks on NVIDIA's platform.
## Configuration
| Field | Type | Required | Default | Description |
|-------|------|----------|---------|-------------|
| `evaluator_url` | `<class 'str'>` | No | http://0.0.0.0:7331 | The url for accessing the evaluator service |
## Sample Configuration
```yaml
evaluator_url: ${env.NVIDIA_EVALUATOR_URL:=http://localhost:7331}
```

View file

@ -0,0 +1,33 @@
# Advanced APIs
## Post-training
Fine-tunes a model.
```{toctree}
:maxdepth: 1
post_training/index
```
## Eval
Generates outputs (via Inference or Agents) and performs scoring.
```{toctree}
:maxdepth: 1
eval/index
```
```{include} evaluation_concepts.md
:start-after: ## Evaluation Concepts
```
## Scoring
Evaluates the outputs of the system.
```{toctree}
:maxdepth: 1
scoring/index
```

View file

@ -0,0 +1,7 @@
# Post_Training Providers
This section contains documentation for all available providers for the **post_training** API.
- [inline::huggingface](inline_huggingface.md)
- [inline::torchtune](inline_torchtune.md)
- [remote::nvidia](remote_nvidia.md)

View file

@ -0,0 +1,33 @@
# inline::huggingface
## Description
HuggingFace-based post-training provider for fine-tuning models using the HuggingFace ecosystem.
## Configuration
| Field | Type | Required | Default | Description |
|-------|------|----------|---------|-------------|
| `device` | `<class 'str'>` | No | cuda | |
| `distributed_backend` | `Literal['fsdp', 'deepspeed']` | No | | |
| `checkpoint_format` | `Literal['full_state', 'huggingface']` | No | huggingface | |
| `chat_template` | `<class 'str'>` | No | | |
| `model_specific_config` | `<class 'dict'>` | No | {'trust_remote_code': True, 'attn_implementation': 'sdpa'} | |
| `max_seq_length` | `<class 'int'>` | No | 2048 | |
| `gradient_checkpointing` | `<class 'bool'>` | No | False | |
| `save_total_limit` | `<class 'int'>` | No | 3 | |
| `logging_steps` | `<class 'int'>` | No | 10 | |
| `warmup_ratio` | `<class 'float'>` | No | 0.1 | |
| `weight_decay` | `<class 'float'>` | No | 0.01 | |
| `dataloader_num_workers` | `<class 'int'>` | No | 4 | |
| `dataloader_pin_memory` | `<class 'bool'>` | No | True | |
## Sample Configuration
```yaml
checkpoint_format: huggingface
distributed_backend: null
device: cpu
```

View file

@ -0,0 +1,20 @@
# inline::torchtune
## Description
TorchTune-based post-training provider for fine-tuning and optimizing models using Meta's TorchTune framework.
## Configuration
| Field | Type | Required | Default | Description |
|-------|------|----------|---------|-------------|
| `torch_seed` | `int \| None` | No | | |
| `checkpoint_format` | `Literal['meta', 'huggingface']` | No | meta | |
## Sample Configuration
```yaml
checkpoint_format: meta
```

View file

@ -0,0 +1,28 @@
# remote::nvidia
## Description
NVIDIA's post-training provider for fine-tuning models on NVIDIA's platform.
## Configuration
| Field | Type | Required | Default | Description |
|-------|------|----------|---------|-------------|
| `api_key` | `str \| None` | No | | The NVIDIA API key. |
| `dataset_namespace` | `str \| None` | No | default | The NVIDIA dataset namespace. |
| `project_id` | `str \| None` | No | test-example-model@v1 | The NVIDIA project ID. |
| `customizer_url` | `str \| None` | No | | Base URL for the NeMo Customizer API |
| `timeout` | `<class 'int'>` | No | 300 | Timeout for the NVIDIA Post Training API |
| `max_retries` | `<class 'int'>` | No | 3 | Maximum number of retries for the NVIDIA Post Training API |
| `output_model_dir` | `<class 'str'>` | No | test-example-model@v1 | Directory to save the output model |
## Sample Configuration
```yaml
api_key: ${env.NVIDIA_API_KEY:=}
dataset_namespace: ${env.NVIDIA_DATASET_NAMESPACE:=default}
project_id: ${env.NVIDIA_PROJECT_ID:=test-project}
customizer_url: ${env.NVIDIA_CUSTOMIZER_URL:=http://nemo.test}
```

View file

@ -0,0 +1,7 @@
# Scoring Providers
This section contains documentation for all available providers for the **scoring** API.
- [inline::basic](inline_basic.md)
- [inline::braintrust](inline_braintrust.md)
- [inline::llm-as-judge](inline_llm-as-judge.md)

View file

@ -0,0 +1,13 @@
# inline::basic
## Description
Basic scoring provider for simple evaluation metrics and scoring functions.
## Sample Configuration
```yaml
{}
```

View file

@ -0,0 +1,19 @@
# inline::braintrust
## Description
Braintrust scoring provider for evaluation and scoring using the Braintrust platform.
## Configuration
| Field | Type | Required | Default | Description |
|-------|------|----------|---------|-------------|
| `openai_api_key` | `str \| None` | No | | The OpenAI API Key |
## Sample Configuration
```yaml
openai_api_key: ${env.OPENAI_API_KEY:=}
```

View file

@ -0,0 +1,13 @@
# inline::llm-as-judge
## Description
LLM-as-judge scoring provider that uses language models to evaluate and score responses.
## Sample Configuration
```yaml
{}
```

View file

@ -1,4 +1,4 @@
# Building AI Applications (Examples)
# AI Application Examples
Llama Stack provides all the building blocks needed to create sophisticated AI applications.
@ -27,4 +27,5 @@ tools
evals
telemetry
safety
```
playground/index
```

View file

@ -1,4 +1,4 @@
# Llama Stack Playground
## Llama Stack Playground
```{note}
The Llama Stack Playground is currently experimental and subject to change. We welcome feedback and contributions to help improve it.
@ -9,7 +9,7 @@ The Llama Stack Playground is an simple interface which aims to:
- Demo **end-to-end** application code to help users get started to build their own applications
- Provide a **UI** to help users inspect and understand Llama Stack API providers and resources
## Key Features
### Key Features
#### Playground
Interactive pages for users to play with and explore Llama Stack API capabilities.
@ -90,7 +90,7 @@ Interactive pages for users to play with and explore Llama Stack API capabilitie
- Under the hood, it uses Llama Stack's `/<resources>/list` API to get information about each resource.
- Please visit [Core Concepts](https://llama-stack.readthedocs.io/en/latest/concepts/index.html) for more details about the resources.
## Starting the Llama Stack Playground
### Starting the Llama Stack Playground
To start the Llama Stack Playground, run the following commands:

View file

@ -1,31 +1,39 @@
# Why Llama Stack?
## Llama Stack architecture
Building production AI applications today requires solving multiple challenges:
**Infrastructure Complexity**
- Running large language models efficiently requires specialized infrastructure.
- Different deployment scenarios (local development, cloud, edge) need different solutions.
- Moving from development to production often requires significant rework.
**Essential Capabilities**
- Safety guardrails and content filtering are necessary in an enterprise setting.
- Just model inference is not enough - Knowledge retrieval and RAG capabilities are required.
- Nearly any application needs composable multi-step workflows.
- Finally, without monitoring, observability and evaluation, you end up operating in the dark.
**Lack of Flexibility and Choice**
- Directly integrating with multiple providers creates tight coupling.
- Different providers have different APIs and abstractions.
- Changing providers requires significant code changes.
### Our Solution: A Universal Stack
Llama Stack allows you to build different layers of distributions for your AI workloads using various SDKs and API providers.
```{image} ../../_static/llama-stack.png
:alt: Llama Stack
:width: 400px
```
### Benefits of Llama Stack
#### Current challenges in custom AI applications
Building production AI applications today requires solving multiple challenges:
**Infrastructure Complexity**
- Running large language models efficiently requires specialized infrastructure.
- Different deployment scenarios (local development, cloud, edge) need different solutions.
- Moving from development to production often requires significant rework.
**Essential Capabilities**
- Safety guardrails and content filtering are necessary in an enterprise setting.
- Just model inference is not enough - Knowledge retrieval and RAG capabilities are required.
- Nearly any application needs composable multi-step workflows.
- Without monitoring, observability and evaluation, you end up operating in the dark.
**Lack of Flexibility and Choice**
- Directly integrating with multiple providers creates tight coupling.
- Different providers have different APIs and abstractions.
- Changing providers requires significant code changes.
#### Our Solution: A Universal Stack
Llama Stack addresses these challenges through a service-oriented, API-first approach:
**Develop Anywhere, Deploy Everywhere**
@ -59,4 +67,4 @@ Llama Stack addresses these challenges through a service-oriented, API-first app
- **Turnkey Solutions**: Easy-to-deploy, built-in solutions for popular deployment scenarios
With Llama Stack, you can focus on building your application while we handle the infrastructure complexity, essential capabilities, and provider integrations.
With Llama Stack, you can focus on building your application while we handle the infrastructure complexity, essential capabilities, and provider integrations.

View file

@ -2,6 +2,10 @@
Given Llama Stack's service-oriented philosophy, a few concepts and workflows arise which may not feel completely natural in the LLM landscape, especially if you are coming with a background in other frameworks.
```{include} architecture.md
:start-after: ## Llama Stack architecture
```
```{include} apis.md
:start-after: ## APIs
```
@ -10,14 +14,10 @@ Given Llama Stack's service-oriented philosophy, a few concepts and workflows ar
:start-after: ## API Providers
```
```{include} resources.md
:start-after: ## Resources
```
```{include} distributions.md
:start-after: ## Distributions
```
```{include} evaluation_concepts.md
:start-after: ## Evaluation Concepts
```{include} resources.md
:start-after: ## Resources
```

View file

@ -52,7 +52,18 @@ extensions = [
"sphinxcontrib.redoc",
"sphinxcontrib.mermaid",
"sphinxcontrib.video",
"sphinx_reredirects"
]
redirects = {
"providers/post_training/index": "../../advanced_apis/post_training/index.html",
"providers/eval/index": "../../advanced_apis/eval/index.html",
"providers/scoring/index": "../../advanced_apis/scoring/index.html",
"playground/index": "../../building_applications/playground/index.html",
"openai/index": "../../providers/index.html#openai-api-compatibility",
"introduction/index": "../concepts/index.html#llama-stack-architecture"
}
myst_enable_extensions = ["colon_fence"]
html_theme = "sphinx_rtd_theme"

View file

@ -0,0 +1,4 @@
# Deployment Examples
```{include} kubernetes_deployment.md
```

View file

@ -1,4 +1,4 @@
# Kubernetes Deployment Guide
## Kubernetes Deployment Guide
Instead of starting the Llama Stack and vLLM servers locally, we can deploy them in a Kubernetes cluster.

View file

@ -6,14 +6,9 @@ This section provides an overview of the distributions available in Llama Stack.
```{toctree}
:maxdepth: 3
list_of_distributions
building_distro
customizing_run_yaml
importing_as_library
configuration
customizing_run_yaml
list_of_distributions
kubernetes_deployment
building_distro
on_device_distro
remote_hosted_distro
self_hosted_distro
```

View file

@ -28,5 +28,4 @@ If you have built a container image and want to deploy it in a Kubernetes cluste
importing_as_library
configuration
kubernetes_deployment
```

View file

@ -1,4 +1,4 @@
# Detailed Tutorial
## Detailed Tutorial
In this guide, we'll walk through how you can use the Llama Stack (server and client SDK) to test a simple agent.
A Llama Stack agent is a simple integrated system that can perform tasks by combining a Llama model for reasoning with
@ -10,7 +10,7 @@ Llama Stack is a stateful service with REST APIs to support seamless transition
In this guide, we'll walk through how to build a RAG agent locally using Llama Stack with [Ollama](https://ollama.com/)
as the inference [provider](../providers/index.md#inference) for a Llama Model.
## Step 1: Installation and Setup
### Step 1: Installation and Setup
Install Ollama by following the instructions on the [Ollama website](https://ollama.com/download), then
download Llama 3.2 3B model, and then start the Ollama service.
@ -45,7 +45,7 @@ Setup your virtual environment.
uv sync --python 3.12
source .venv/bin/activate
```
## Step 2: Run Llama Stack
### Step 2: Run Llama Stack
Llama Stack is a server that exposes multiple APIs; you connect to it using the Llama Stack client SDK.
::::{tab-set}
@ -132,7 +132,7 @@ Now you can use the Llama Stack client to run inference and build agents!
You can reuse the server setup or use the [Llama Stack Client](https://github.com/meta-llama/llama-stack-client-python/).
Note that the client package is already included in the `llama-stack` package.
## Step 3: Run Client CLI
### Step 3: Run Client CLI
Open a new terminal and navigate to the same directory you started the server from. Then set up a new virtual
environment or activate your existing server virtual environment.
@ -232,7 +232,7 @@ OpenAIChatCompletion(
)
```
## Step 4: Run the Demos
### Step 4: Run the Demos
Note that these demos show the [Python Client SDK](../references/python_sdk_reference/index.md).
Other SDKs are also available; please refer to the [Client SDK](../index.md#client-sdks) list for the complete options.
@ -242,7 +242,7 @@ Other SDKs are also available, please refer to the [Client SDK](../index.md#clie
:::{tab-item} Basic Inference
Now you can run inference using the Llama Stack client SDK.
### i. Create the Script
#### i. Create the Script
Create a file `inference.py` and add the following code:
```python
@ -269,7 +269,7 @@ response = client.chat.completions.create(
print(response)
```
### ii. Run the Script
#### ii. Run the Script
Let's run the script using `uv`
```bash
uv run python inference.py
@ -283,7 +283,7 @@ OpenAIChatCompletion(id='chatcmpl-30cd0f28-a2ad-4b6d-934b-13707fc60ebf', choices
:::{tab-item} Build a Simple Agent
Next we can move beyond simple inference and build an agent that can perform tasks using the Llama Stack server.
### i. Create the Script
#### i. Create the Script
Create a file `agent.py` and add the following code:
```python
@ -455,7 +455,7 @@ uv run python agent.py
For our last demo, we can build a RAG agent that can answer questions about the Torchtune project using the documents
in a vector database.
### i. Create the Script
#### i. Create the Script
Create a file `rag_agent.py` and add the following code:
```python
@ -533,7 +533,7 @@ for t in turns:
for event in AgentEventLogger().log(stream):
event.print()
```
### ii. Run the Script
#### ii. Run the Script
Let's run the script using `uv`
```bash
uv run python rag_agent.py

View file

@ -1,123 +1,13 @@
# Quickstart
# Getting Started
Get started with Llama Stack in minutes!
Llama Stack is a stateful service with REST APIs to support the seamless transition of AI applications across different
environments. You can build and test using a local server first and deploy to a hosted endpoint for production.
In this guide, we'll walk through how to build a RAG application locally using Llama Stack with [Ollama](https://ollama.com/)
as the inference [provider](../providers/inference/index) for a Llama Model.
**💡 Notebook Version:** You can also follow this quickstart guide in a Jupyter notebook format: [quick_start.ipynb](https://github.com/meta-llama/llama-stack/blob/main/docs/quick_start.ipynb)
#### Step 1: Install and setup
1. Install [uv](https://docs.astral.sh/uv/)
2. Run inference on a Llama model with [Ollama](https://ollama.com/download)
```bash
ollama run llama3.2:3b --keepalive 60m
```{include} quickstart.md
:start-after: ## Quickstart
```
#### Step 2: Run the Llama Stack server
We will use `uv` to run the Llama Stack server.
```bash
INFERENCE_MODEL=llama3.2:3b uv run --with llama-stack llama stack build --template starter --image-type venv --run
```{include} libraries.md
:start-after: ## Libraries (SDKs)
```
#### Step 3: Run the demo
Now open up a new terminal and copy the following script into a file named `demo_script.py`.
```python
from llama_stack_client import Agent, AgentEventLogger, RAGDocument, LlamaStackClient
vector_db_id = "my_demo_vector_db"
client = LlamaStackClient(base_url="http://localhost:8321")
models = client.models.list()
# Select the first LLM and first embedding models
model_id = next(m for m in models if m.model_type == "llm").identifier
embedding_model_id = (
em := next(m for m in models if m.model_type == "embedding")
).identifier
embedding_dimension = em.metadata["embedding_dimension"]
_ = client.vector_dbs.register(
vector_db_id=vector_db_id,
embedding_model=embedding_model_id,
embedding_dimension=embedding_dimension,
provider_id="faiss",
)
source = "https://www.paulgraham.com/greatwork.html"
print("rag_tool> Ingesting document:", source)
document = RAGDocument(
document_id="document_1",
content=source,
mime_type="text/html",
metadata={},
)
client.tool_runtime.rag_tool.insert(
documents=[document],
vector_db_id=vector_db_id,
chunk_size_in_tokens=50,
)
agent = Agent(
client,
model=model_id,
instructions="You are a helpful assistant",
tools=[
{
"name": "builtin::rag/knowledge_search",
"args": {"vector_db_ids": [vector_db_id]},
}
],
)
prompt = "How do you do great work?"
print("prompt>", prompt)
response = agent.create_turn(
messages=[{"role": "user", "content": prompt}],
session_id=agent.create_session("rag_session"),
stream=True,
)
for log in AgentEventLogger().log(response):
log.print()
```{include} detailed_tutorial.md
:start-after: ## Detailed Tutorial
```
We will use `uv` to run the script
```
uv run --with llama-stack-client,fire,requests demo_script.py
```
And you should see output like below.
```
rag_tool> Ingesting document: https://www.paulgraham.com/greatwork.html
prompt> How do you do great work?
inference> [knowledge_search(query="What is the key to doing great work")]
tool_execution> Tool:knowledge_search Args:{'query': 'What is the key to doing great work'}
tool_execution> Tool:knowledge_search Response:[TextContentItem(text='knowledge_search tool found 5 chunks:\nBEGIN of knowledge_search tool results.\n', type='text'), TextContentItem(text="Result 1:\nDocument_id:docum\nContent: work. Doing great work means doing something important\nso well that you expand people's ideas of what's possible. But\nthere's no threshold for importance. It's a matter of degree, and\noften hard to judge at the time anyway.\n", type='text'), TextContentItem(text="Result 2:\nDocument_id:docum\nContent: work. Doing great work means doing something important\nso well that you expand people's ideas of what's possible. But\nthere's no threshold for importance. It's a matter of degree, and\noften hard to judge at the time anyway.\n", type='text'), TextContentItem(text="Result 3:\nDocument_id:docum\nContent: work. Doing great work means doing something important\nso well that you expand people's ideas of what's possible. But\nthere's no threshold for importance. It's a matter of degree, and\noften hard to judge at the time anyway.\n", type='text'), TextContentItem(text="Result 4:\nDocument_id:docum\nContent: work. Doing great work means doing something important\nso well that you expand people's ideas of what's possible. But\nthere's no threshold for importance. It's a matter of degree, and\noften hard to judge at the time anyway.\n", type='text'), TextContentItem(text="Result 5:\nDocument_id:docum\nContent: work. Doing great work means doing something important\nso well that you expand people's ideas of what's possible. But\nthere's no threshold for importance. It's a matter of degree, and\noften hard to judge at the time anyway.\n", type='text'), TextContentItem(text='END of knowledge_search tool results.\n', type='text')]
inference> Based on the search results, it seems that doing great work means doing something important so well that you expand people's ideas of what's possible. However, there is no clear threshold for importance, and it can be difficult to judge at the time.
To further clarify, I would suggest that doing great work involves:
* Completing tasks with high quality and attention to detail
* Expanding on existing knowledge or ideas
* Making a positive impact on others through your work
* Striving for excellence and continuous improvement
Ultimately, great work is about making a meaningful contribution and leaving a lasting impression.
```
Congratulations! You've successfully built your first RAG application using Llama Stack! 🎉🥳
## Next Steps
Now you're ready to dive deeper into Llama Stack!
- Explore the [Detailed Tutorial](./detailed_tutorial.md).
- Try the [Getting Started Notebook](https://github.com/meta-llama/llama-stack/blob/main/docs/getting_started.ipynb).
- Browse more [Notebooks on GitHub](https://github.com/meta-llama/llama-stack/tree/main/docs/notebooks).
- Learn about Llama Stack [Concepts](../concepts/index.md).
- Discover how to [Build Llama Stacks](../distributions/index.md).
- Refer to our [References](../references/index.md) for details on the Llama CLI and Python SDK.
- Check out the [llama-stack-apps](https://github.com/meta-llama/llama-stack-apps/tree/main/examples) repository for example applications and tutorials.

View file

@ -0,0 +1,10 @@
## Libraries (SDKs)
We have a number of client-side SDKs available for different languages.
| **Language** | **Client SDK** | **Package** |
| :----: | :----: | :----: |
| Python | [llama-stack-client-python](https://github.com/meta-llama/llama-stack-client-python) | [![PyPI version](https://img.shields.io/pypi/v/llama_stack_client.svg)](https://pypi.org/project/llama_stack_client/)
| Swift | [llama-stack-client-swift](https://github.com/meta-llama/llama-stack-client-swift/tree/latest-release) | [![Swift Package Index](https://img.shields.io/endpoint?url=https%3A%2F%2Fswiftpackageindex.com%2Fapi%2Fpackages%2Fmeta-llama%2Fllama-stack-client-swift%2Fbadge%3Ftype%3Dswift-versions)](https://swiftpackageindex.com/meta-llama/llama-stack-client-swift)
| Node | [llama-stack-client-node](https://github.com/meta-llama/llama-stack-client-node) | [![NPM version](https://img.shields.io/npm/v/llama-stack-client.svg)](https://npmjs.org/package/llama-stack-client)
| Kotlin | [llama-stack-client-kotlin](https://github.com/meta-llama/llama-stack-client-kotlin/tree/latest-release) | [![Maven version](https://img.shields.io/maven-central/v/com.llama.llamastack/llama-stack-client-kotlin)](https://central.sonatype.com/artifact/com.llama.llamastack/llama-stack-client-kotlin)

View file

@ -0,0 +1,123 @@
## Quickstart
Get started with Llama Stack in minutes!
Llama Stack is a stateful service with REST APIs to support the seamless transition of AI applications across different
environments. You can build and test using a local server first and deploy to a hosted endpoint for production.
In this guide, we'll walk through how to build a RAG application locally using Llama Stack with [Ollama](https://ollama.com/)
as the inference [provider](../providers/inference/index) for a Llama Model.
**💡 Notebook Version:** You can also follow this quickstart guide in a Jupyter notebook format: [quick_start.ipynb](https://github.com/meta-llama/llama-stack/blob/main/docs/quick_start.ipynb)
#### Step 1: Install and setup
1. Install [uv](https://docs.astral.sh/uv/)
2. Run inference on a Llama model with [Ollama](https://ollama.com/download)
```bash
ollama run llama3.2:3b --keepalive 60m
```
#### Step 2: Run the Llama Stack server
We will use `uv` to run the Llama Stack server.
```bash
INFERENCE_MODEL=llama3.2:3b uv run --with llama-stack llama stack build --template starter --image-type venv --run
```
#### Step 3: Run the demo
Now open up a new terminal and copy the following script into a file named `demo_script.py`.
```python
from llama_stack_client import Agent, AgentEventLogger, RAGDocument, LlamaStackClient
vector_db_id = "my_demo_vector_db"
client = LlamaStackClient(base_url="http://localhost:8321")
models = client.models.list()
# Select the first LLM and first embedding models
model_id = next(m for m in models if m.model_type == "llm").identifier
embedding_model_id = (
em := next(m for m in models if m.model_type == "embedding")
).identifier
embedding_dimension = em.metadata["embedding_dimension"]
_ = client.vector_dbs.register(
vector_db_id=vector_db_id,
embedding_model=embedding_model_id,
embedding_dimension=embedding_dimension,
provider_id="faiss",
)
source = "https://www.paulgraham.com/greatwork.html"
print("rag_tool> Ingesting document:", source)
document = RAGDocument(
document_id="document_1",
content=source,
mime_type="text/html",
metadata={},
)
client.tool_runtime.rag_tool.insert(
documents=[document],
vector_db_id=vector_db_id,
chunk_size_in_tokens=50,
)
agent = Agent(
client,
model=model_id,
instructions="You are a helpful assistant",
tools=[
{
"name": "builtin::rag/knowledge_search",
"args": {"vector_db_ids": [vector_db_id]},
}
],
)
prompt = "How do you do great work?"
print("prompt>", prompt)
response = agent.create_turn(
messages=[{"role": "user", "content": prompt}],
session_id=agent.create_session("rag_session"),
stream=True,
)
for log in AgentEventLogger().log(response):
log.print()
```
We will use `uv` to run the script
```
uv run --with llama-stack-client,fire,requests demo_script.py
```
And you should see output like below.
```
rag_tool> Ingesting document: https://www.paulgraham.com/greatwork.html
prompt> How do you do great work?
inference> [knowledge_search(query="What is the key to doing great work")]
tool_execution> Tool:knowledge_search Args:{'query': 'What is the key to doing great work'}
tool_execution> Tool:knowledge_search Response:[TextContentItem(text='knowledge_search tool found 5 chunks:\nBEGIN of knowledge_search tool results.\n', type='text'), TextContentItem(text="Result 1:\nDocument_id:docum\nContent: work. Doing great work means doing something important\nso well that you expand people's ideas of what's possible. But\nthere's no threshold for importance. It's a matter of degree, and\noften hard to judge at the time anyway.\n", type='text'), TextContentItem(text="Result 2:\nDocument_id:docum\nContent: work. Doing great work means doing something important\nso well that you expand people's ideas of what's possible. But\nthere's no threshold for importance. It's a matter of degree, and\noften hard to judge at the time anyway.\n", type='text'), TextContentItem(text="Result 3:\nDocument_id:docum\nContent: work. Doing great work means doing something important\nso well that you expand people's ideas of what's possible. But\nthere's no threshold for importance. It's a matter of degree, and\noften hard to judge at the time anyway.\n", type='text'), TextContentItem(text="Result 4:\nDocument_id:docum\nContent: work. Doing great work means doing something important\nso well that you expand people's ideas of what's possible. But\nthere's no threshold for importance. It's a matter of degree, and\noften hard to judge at the time anyway.\n", type='text'), TextContentItem(text="Result 5:\nDocument_id:docum\nContent: work. Doing great work means doing something important\nso well that you expand people's ideas of what's possible. But\nthere's no threshold for importance. It's a matter of degree, and\noften hard to judge at the time anyway.\n", type='text'), TextContentItem(text='END of knowledge_search tool results.\n', type='text')]
inference> Based on the search results, it seems that doing great work means doing something important so well that you expand people's ideas of what's possible. However, there is no clear threshold for importance, and it can be difficult to judge at the time.
To further clarify, I would suggest that doing great work involves:
* Completing tasks with high quality and attention to detail
* Expanding on existing knowledge or ideas
* Making a positive impact on others through your work
* Striving for excellence and continuous improvement
Ultimately, great work is about making a meaningful contribution and leaving a lasting impression.
```
Congratulations! You've successfully built your first RAG application using Llama Stack! 🎉🥳
### Next Steps
Now you're ready to dive deeper into Llama Stack!
- Explore the [Detailed Tutorial](./detailed_tutorial.md).
- Try the [Getting Started Notebook](https://github.com/meta-llama/llama-stack/blob/main/docs/getting_started.ipynb).
- Browse more [Notebooks on GitHub](https://github.com/meta-llama/llama-stack/tree/main/docs/notebooks).
- Learn about Llama Stack [Concepts](../concepts/index.md).
- Discover how to [Build Llama Stacks](../distributions/index.md).
- Refer to our [References](../references/index.md) for details on the Llama CLI and Python SDK.
- Check out the [llama-stack-apps](https://github.com/meta-llama/llama-stack-apps/tree/main/examples) repository for example applications and tutorials.

View file

@ -40,17 +40,6 @@ Kotlin.
- Ready to build? Check out the [Quick Start](getting_started/index) to get started.
- Want to contribute? See the [Contributing](contributing/index) guide.
## Client SDKs
We have a number of client-side SDKs available for different languages.
| **Language** | **Client SDK** | **Package** |
| :----: | :----: | :----: |
| Python | [llama-stack-client-python](https://github.com/meta-llama/llama-stack-client-python) | [![PyPI version](https://img.shields.io/pypi/v/llama_stack_client.svg)](https://pypi.org/project/llama_stack_client/)
| Swift | [llama-stack-client-swift](https://github.com/meta-llama/llama-stack-client-swift/tree/latest-release) | [![Swift Package Index](https://img.shields.io/endpoint?url=https%3A%2F%2Fswiftpackageindex.com%2Fapi%2Fpackages%2Fmeta-llama%2Fllama-stack-client-swift%2Fbadge%3Ftype%3Dswift-versions)](https://swiftpackageindex.com/meta-llama/llama-stack-client-swift)
| Node | [llama-stack-client-node](https://github.com/meta-llama/llama-stack-client-node) | [![NPM version](https://img.shields.io/npm/v/llama-stack-client.svg)](https://npmjs.org/package/llama-stack-client)
| Kotlin | [llama-stack-client-kotlin](https://github.com/meta-llama/llama-stack-client-kotlin/tree/latest-release) | [![Maven version](https://img.shields.io/maven-central/v/com.llama.llamastack/llama-stack-client-kotlin)](https://central.sonatype.com/artifact/com.llama.llamastack/llama-stack-client-kotlin)
## Supported Llama Stack Implementations
A number of "adapters" are available for some popular Inference and Vector Store providers. For other APIs (particularly Safety and Agents), we provide *reference implementations* you can use to get started. We expect this list to grow over time. We are slowly onboarding more providers to the ecosystem as we get more confidence in the APIs.
@ -133,14 +122,12 @@ A number of "adapters" are available for some popular Inference and Vector Store
self
getting_started/index
getting_started/detailed_tutorial
introduction/index
concepts/index
openai/index
providers/index
distributions/index
advanced_apis/index
building_applications/index
playground/index
deploying/index
contributing/index
references/index
```

View file

@ -1,4 +1,4 @@
# Providers Overview
# API Providers Overview
The goal of Llama Stack is to build an ecosystem where users can easily swap out different implementations for the same API. Examples of these include:
- LLM inference providers (e.g., Meta Reference, Ollama, Fireworks, Together, AWS Bedrock, Groq, Cerebras, SambaNova, vLLM, OpenAI, Anthropic, Gemini, WatsonX, etc.),
@ -13,13 +13,25 @@ Providers come in two flavors:
Importantly, Llama Stack always strives to provide at least one fully inline provider for each API so you can iterate on a fully featured environment locally.
## External Providers
Llama Stack supports external providers that live outside of the main codebase. This allows you to create and maintain your own providers independently.
```{toctree}
:maxdepth: 1
external
external.md
```
```{include} openai.md
:start-after: ## OpenAI API Compatibility
```
## Inference
Runs inference with an LLM.
```{toctree}
:maxdepth: 1
inference/index
```
## Agents
@ -40,33 +52,6 @@ Interfaces with datasets and data loaders.
datasetio/index
```
## Eval
Generates outputs (via Inference or Agents) and perform scoring.
```{toctree}
:maxdepth: 1
eval/index
```
## Inference
Runs inference with an LLM.
```{toctree}
:maxdepth: 1
inference/index
```
## Post Training
Fine-tunes a model.
```{toctree}
:maxdepth: 1
post_training/index
```
## Safety
Applies safety policies to the output at a Systems (not only model) level.
@ -76,15 +61,6 @@ Applies safety policies to the output at a Systems (not only model) level.
safety/index
```
## Scoring
Evaluates the outputs of the system.
```{toctree}
:maxdepth: 1
scoring/index
```
## Telemetry
Collects telemetry data from the system.
@ -94,15 +70,6 @@ Collects telemetry data from the system.
telemetry/index
```
## Tool Runtime
Is associated with the ToolGroup resouces.
```{toctree}
:maxdepth: 1
tool_runtime/index
```
## Vector IO
Vector IO refers to operations on vector databases, such as adding documents, searching, and deleting documents.
@ -114,3 +81,12 @@ io and database are used to store and retrieve documents for retrieval.
vector_io/index
```
## Tool Runtime
Is associated with the ToolGroup resources.
```{toctree}
:maxdepth: 1
tool_runtime/index
```

View file

@ -1,14 +1,14 @@
# OpenAI API Compatibility
## OpenAI API Compatibility
## Server path
### Server path
Llama Stack exposes an OpenAI-compatible API endpoint at `/v1/openai/v1`. So, for a Llama Stack server running locally on port `8321`, the full URL to the OpenAI-compatible API endpoint is `http://localhost:8321/v1/openai/v1`.
## Clients
### Clients
You should be able to use any client that speaks OpenAI APIs with Llama Stack. We regularly test with the official Llama Stack clients as well as OpenAI's official Python client.
### Llama Stack Client
#### Llama Stack Client
When using the Llama Stack client, set the `base_url` to the root of your Llama Stack server. It will automatically route OpenAI-compatible requests to the right server endpoint for you.
@ -18,7 +18,7 @@ from llama_stack_client import LlamaStackClient
client = LlamaStackClient(base_url="http://localhost:8321")
```
### OpenAI Client
#### OpenAI Client
When using an OpenAI client, set the `base_url` to the `/v1/openai/v1` path on your Llama Stack server.
@ -30,9 +30,9 @@ client = OpenAI(base_url="http://localhost:8321/v1/openai/v1", api_key="none")
Regardless of the client you choose, the following code examples should all work the same.
## APIs implemented
### APIs implemented
### Models
#### Models
Many of the APIs require you to pass in a model parameter. To see the list of models available in your Llama Stack server:
@ -40,13 +40,13 @@ Many of the APIs require you to pass in a model parameter. To see the list of mo
models = client.models.list()
```
### Responses
#### Responses
:::{note}
The Responses API implementation is still in active development. While it is quite usable, there are still unimplemented parts of the API. We'd love feedback on any use-cases you try that do not work to help prioritize the pieces left to implement. Please open issues in the [meta-llama/llama-stack](https://github.com/meta-llama/llama-stack) GitHub repository with details of anything that does not work.
:::
#### Simple inference
##### Simple inference
Request:
@ -66,7 +66,7 @@ Syntax whispers secrets sweet
Code's gentle silence
```
#### Structured Output
##### Structured Output
Request:
@ -106,9 +106,9 @@ Example output:
{ "participants": ["Alice", "Bob"] }
```
### Chat Completions
#### Chat Completions
#### Simple inference
##### Simple inference
Request:
@ -129,7 +129,7 @@ Logic flows like a river
Code's gentle beauty
```
#### Structured Output
##### Structured Output
Request:
@ -170,9 +170,9 @@ Example output:
{ "participants": ["Alice", "Bob"] }
```
### Completions
#### Completions
#### Simple inference
##### Simple inference
Request:

View file

@ -114,7 +114,7 @@ For more details on TLS configuration, refer to the [TLS setup guide](https://mi
| `uri` | `<class 'str'>` | No | PydanticUndefined | The URI of the Milvus server |
| `token` | `str \| None` | No | PydanticUndefined | The token of the Milvus server |
| `consistency_level` | `<class 'str'>` | No | Strong | The consistency level of the Milvus server |
| `kvstore` | `utils.kvstore.config.RedisKVStoreConfig \| utils.kvstore.config.SqliteKVStoreConfig \| utils.kvstore.config.PostgresKVStoreConfig \| utils.kvstore.config.MongoDBKVStoreConfig, annotation=NoneType, required=False, default='sqlite', discriminator='type'` | No | | Config for KV store backend (SQLite only for now) |
| `kvstore` | `utils.kvstore.config.RedisKVStoreConfig \| utils.kvstore.config.SqliteKVStoreConfig \| utils.kvstore.config.PostgresKVStoreConfig \| utils.kvstore.config.MongoDBKVStoreConfig` | No | sqlite | Config for KV store backend |
| `config` | `dict` | No | {} | This configuration allows additional fields to be passed through to the underlying Milvus client. See the [Milvus](https://milvus.io/docs/install-overview.md) documentation for more details about Milvus in general. |
> **Note**: This configuration class accepts additional fields beyond those listed above. You can pass any additional configuration options that will be forwarded to the underlying provider.
@ -124,6 +124,9 @@ For more details on TLS configuration, refer to the [TLS setup guide](https://mi
```yaml
uri: ${env.MILVUS_ENDPOINT}
token: ${env.MILVUS_TOKEN}
kvstore:
  type: sqlite
  db_path: ${env.SQLITE_STORE_DIR:=~/.llama/dummy}/milvus_remote_registry.db
```

View file

@ -9,7 +9,8 @@ The `llama-stack-client` CLI allows you to query information about the distribut
llama-stack-client
Usage: llama-stack-client [OPTIONS] COMMAND [ARGS]...
Welcome to the LlamaStackClient CLI
Welcome to the llama-stack-client CLI - a command-line interface for
interacting with Llama Stack
Options:
--version Show the version and exit.
@ -35,6 +36,7 @@ Commands:
```
### `llama-stack-client configure`
Configure Llama Stack Client CLI.
```bash
llama-stack-client configure
> Enter the host name of the Llama Stack distribution server: localhost
@ -42,7 +44,24 @@ llama-stack-client configure
Done! You can now use the Llama Stack Client CLI with endpoint http://localhost:8321
```
Optional arguments:
- `--endpoint`: Llama Stack distribution endpoint
- `--api-key`: Llama Stack distribution API key
### `llama-stack-client inspect version`
Inspect server configuration.
```bash
llama-stack-client inspect version
```
```bash
VersionInfo(version='0.2.14')
```
### `llama-stack-client providers list`
Show available providers on distribution endpoint
```bash
llama-stack-client providers list
```
@ -66,9 +85,74 @@ llama-stack-client providers list
+-----------+----------------+-----------------+
```
### `llama-stack-client providers inspect`
Show specific provider configuration on distribution endpoint
```bash
llama-stack-client providers inspect <provider_id>
```
## Inference
Inference (chat).
### `llama-stack-client inference chat-completion`
Show available inference chat completion endpoints on distribution endpoint
```bash
llama-stack-client inference chat-completion --message <message> [--stream] [--session] [--model-id]
```
```bash
OpenAIChatCompletion(
id='chatcmpl-aacd11f3-8899-4ec5-ac5b-e655132f6891',
choices=[
OpenAIChatCompletionChoice(
finish_reason='stop',
index=0,
message=OpenAIChatCompletionChoiceMessageOpenAIAssistantMessageParam(
role='assistant',
content='The captain of the whaleship Pequod in Nathaniel Hawthorne\'s novel "Moby-Dick" is Captain
Ahab. He\'s a vengeful and obsessive old sailor who\'s determined to hunt down and kill the white sperm whale
Moby-Dick, whom he\'s lost his leg to in a previous encounter.',
name=None,
tool_calls=None,
refusal=None,
annotations=None,
audio=None,
function_call=None
),
logprobs=None
)
],
created=1752578797,
model='llama3.2:3b-instruct-fp16',
object='chat.completion',
service_tier=None,
system_fingerprint='fp_ollama',
usage={
'completion_tokens': 67,
'prompt_tokens': 33,
'total_tokens': 100,
'completion_tokens_details': None,
'prompt_tokens_details': None
}
)
```
Required arguments:
**Note:** At least one of these parameters is required for chat completion
- `--message`: Message
- `--session`: Start a Chat Session
Optional arguments:
- `--stream`: Stream
- `--model-id`: Model ID
## Model Management
Manage GenAI models.
### `llama-stack-client models list`
Show available llama models at distribution endpoint
```bash
llama-stack-client models list
```
@ -85,6 +169,7 @@ Total models: 1
```
### `llama-stack-client models get`
Show details of a specific model at the distribution endpoint
```bash
llama-stack-client models get Llama3.1-8B-Instruct
```
@ -105,69 +190,92 @@ Model RandomModel is not found at distribution endpoint host:port. Please ensure
```
### `llama-stack-client models register`
Register a new model at distribution endpoint
```bash
llama-stack-client models register <model_id> [--provider-id <provider_id>] [--provider-model-id <provider_model_id>] [--metadata <metadata>]
llama-stack-client models register <model_id> [--provider-id <provider_id>] [--provider-model-id <provider_model_id>] [--metadata <metadata>] [--model-type <model_type>]
```
### `llama-stack-client models update`
Required arguments:
- `MODEL_ID`: Model ID
- `--provider-id`: Provider ID for the model
Optional arguments:
- `--provider-model-id`: Provider's model ID
- `--metadata`: JSON metadata for the model
- `--model-type`: Model type: `llm`, `embedding`
### `llama-stack-client models unregister`
Unregister a model from distribution endpoint
```bash
llama-stack-client models update <model_id> [--provider-id <provider_id>] [--provider-model-id <provider_model_id>] [--metadata <metadata>]
```
### `llama-stack-client models delete`
```bash
llama-stack-client models delete <model_id>
llama-stack-client models unregister <model_id>
```
## Vector DB Management
Manage vector databases.
### `llama-stack-client vector_dbs list`
Show available vector dbs on distribution endpoint
```bash
llama-stack-client vector_dbs list
```
```
+--------------+----------------+---------------------+---------------+------------------------+
| identifier | provider_id | provider_resource_id| vector_db_type| params |
+==============+================+=====================+===============+========================+
| test_bank | meta-reference | test_bank | vector | embedding_model: all-MiniLM-L6-v2
embedding_dimension: 384|
+--------------+----------------+---------------------+---------------+------------------------+
┏━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃ identifier ┃ provider_id ┃ provider_resource_id ┃ vector_db_type ┃ params ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│ my_demo_vector_db │ faiss │ my_demo_vector_db │ │ embedding_dimension: 384 │
│ │ │ │ │ embedding_model: all-MiniLM-L6-v2 │
│ │ │ │ │ type: vector_db │
│ │ │ │ │ │
└──────────────────────────┴─────────────┴──────────────────────────┴────────────────┴───────────────────────────────────┘
```
### `llama-stack-client vector_dbs register`
Create a new vector db
```bash
llama-stack-client vector_dbs register <vector-db-id> [--provider-id <provider-id>] [--provider-vector-db-id <provider-vector-db-id>] [--embedding-model <embedding-model>] [--embedding-dimension <embedding-dimension>]
```
Required arguments:
- `VECTOR_DB_ID`: Vector DB ID
Optional arguments:
- `--provider-id`: Provider ID for the vector db
- `--provider-vector-db-id`: Provider's vector db ID
- `--embedding-model`: Embedding model to use. Default: "all-MiniLM-L6-v2"
- `--embedding-model`: Embedding model to use. Default: `all-MiniLM-L6-v2`
- `--embedding-dimension`: Dimension of embeddings. Default: 384
### `llama-stack-client vector_dbs unregister`
Delete a vector db
```bash
llama-stack-client vector_dbs unregister <vector-db-id>
```
Required arguments:
- `VECTOR_DB_ID`: Vector DB ID
## Shield Management
Manage safety shield services.
### `llama-stack-client shields list`
Show available safety shields on distribution endpoint
```bash
llama-stack-client shields list
```
```
+--------------+----------+----------------+-------------+
| identifier | params | provider_id | type |
+==============+==========+================+=============+
| llama_guard | {} | meta-reference | llama_guard |
+--------------+----------+----------------+-------------+
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃ identifier ┃ provider_alias ┃ params ┃ provider_id ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│ ollama │ ollama/llama-guard3:1b │ │ llama-guard │
└──────────────────────────────────┴───────────────────────────────────────────────────────────────────────┴───────────────────────┴────────────────────────────────────┘
```
### `llama-stack-client shields register`
Register a new safety shield
```bash
llama-stack-client shields register --shield-id <shield-id> [--provider-id <provider-id>] [--provider-shield-id <provider-shield-id>] [--params <params>]
```
@ -180,41 +288,29 @@ Optional arguments:
- `--provider-shield-id`: Provider's shield ID
- `--params`: JSON configuration parameters for the shield
## Eval Task Management
### `llama-stack-client benchmarks list`
```bash
llama-stack-client benchmarks list
```
### `llama-stack-client benchmarks register`
```bash
llama-stack-client benchmarks register --eval-task-id <eval-task-id> --dataset-id <dataset-id> --scoring-functions <function1> [<function2> ...] [--provider-id <provider-id>] [--provider-eval-task-id <provider-eval-task-id>] [--metadata <metadata>]
```
Required arguments:
- `--eval-task-id`: ID of the eval task
- `--dataset-id`: ID of the dataset to evaluate
- `--scoring-functions`: One or more scoring functions to use for evaluation
Optional arguments:
- `--provider-id`: Provider ID for the eval task
- `--provider-eval-task-id`: Provider's eval task ID
- `--metadata`: Metadata for the eval task in JSON format
## Eval execution
Run evaluation tasks.
### `llama-stack-client eval run-benchmark`
Run an evaluation benchmark task
```bash
llama-stack-client eval run-benchmark <eval-task-id1> [<eval-task-id2> ...] --eval-task-config <config-file> --output-dir <output-dir> [--num-examples <num>] [--visualize]
llama-stack-client eval run-benchmark <eval-task-id1> [<eval-task-id2> ...] --eval-task-config <config-file> --output-dir <output-dir> --model-id <model-id> [--num-examples <num>] [--visualize] [--repeat-penalty <repeat-penalty>] [--top-p <top-p>] [--max-tokens <max-tokens>] [--temperature <temperature>]
```
Required arguments:
- `--eval-task-config`: Path to the eval task config file in JSON format
- `--output-dir`: Path to the directory where evaluation results will be saved
- `--model-id`: model id to run the benchmark eval on
Optional arguments:
- `--num-examples`: Number of examples to evaluate (useful for debugging)
- `--visualize`: If set, visualizes evaluation results after completion
- `--repeat-penalty`: repeat-penalty in the sampling params to run generation
- `--top-p`: top-p in the sampling params to run generation
- `--max-tokens`: max-tokens in the sampling params to run generation
- `--temperature`: temperature in the sampling params to run generation
Example benchmark_config.json:
```json
@ -231,21 +327,55 @@ Example benchmark_config.json:
```
### `llama-stack-client eval run-scoring`
Run scoring from application datasets
```bash
llama-stack-client eval run-scoring <eval-task-id> --eval-task-config <config-file> --output-dir <output-dir> [--num-examples <num>] [--visualize]
llama-stack-client eval run-scoring <eval-task-id> --output-dir <output-dir> [--num-examples <num>] [--visualize]
```
Required arguments:
- `--eval-task-config`: Path to the eval task config file in JSON format
- `--output-dir`: Path to the directory where scoring results will be saved
Optional arguments:
- `--num-examples`: Number of examples to evaluate (useful for debugging)
- `--visualize`: If set, visualizes scoring results after completion
- `--scoring-params-config`: Path to the scoring params config file in JSON format
- `--dataset-id`: Pre-registered dataset_id to score (from llama-stack-client datasets list)
- `--dataset-path`: Path to the dataset file to score
## Eval Tasks
Manage evaluation tasks.
### `llama-stack-client eval_tasks list`
Show available eval tasks on distribution endpoint
```bash
llama-stack-client eval_tasks list
```
### `llama-stack-client eval_tasks register`
Register a new eval task
```bash
llama-stack-client eval_tasks register --eval-task-id <eval-task-id> --dataset-id <dataset-id> --scoring-functions <scoring-functions> [--provider-id <provider-id>] [--provider-eval-task-id <provider-eval-task-id>] [--metadata <metadata>]
```
Required arguments:
- `--eval-task-id`: ID of the eval task
- `--dataset-id`: ID of the dataset to evaluate
- `--scoring-functions`: Scoring functions to use for evaluation
Optional arguments:
- `--provider-id`: Provider ID for the eval task
- `--provider-eval-task-id`: Provider's eval task ID
## Tool Group Management
Manage available tool groups.
### `llama-stack-client toolgroups list`
Show available llama toolgroups at distribution endpoint
```bash
llama-stack-client toolgroups list
```
@ -260,17 +390,28 @@ llama-stack-client toolgroups list
```
### `llama-stack-client toolgroups get`
Get available llama toolgroups by id
```bash
llama-stack-client toolgroups get <toolgroup_id>
```
Shows detailed information about a specific toolgroup. If the toolgroup is not found, displays an error message.
Required arguments:
- `TOOLGROUP_ID`: ID of the tool group
### `llama-stack-client toolgroups register`
Register a new toolgroup at distribution endpoint
```bash
llama-stack-client toolgroups register <toolgroup_id> [--provider-id <provider-id>] [--provider-toolgroup-id <provider-toolgroup-id>] [--mcp-config <mcp-config>] [--args <args>]
```
Required arguments:
- `TOOLGROUP_ID`: ID of the tool group
Optional arguments:
- `--provider-id`: Provider ID for the toolgroup
- `--provider-toolgroup-id`: Provider's toolgroup ID
@ -278,6 +419,172 @@ Optional arguments:
- `--args`: JSON arguments for the toolgroup
### `llama-stack-client toolgroups unregister`
Unregister a toolgroup from distribution endpoint
```bash
llama-stack-client toolgroups unregister <toolgroup_id>
```
Required arguments:
- `TOOLGROUP_ID`: ID of the tool group
## Datasets Management
Manage datasets.
### `llama-stack-client datasets list`
Show available datasets on distribution endpoint
```bash
llama-stack-client datasets list
```
### `llama-stack-client datasets register`
```bash
llama-stack-client datasets register --dataset_id <dataset_id> --purpose <purpose> [--url <url>] [--dataset-path <dataset-path>] [--dataset-id <dataset-id>] [--metadata <metadata>]
```
Required arguments:
- `--dataset_id`: Id of the dataset
- `--purpose`: Purpose of the dataset
Optional arguments:
- `--metadata`: Metadata of the dataset
- `--url`: URL of the dataset
- `--dataset-path`: Local file path to the dataset. If specified, upload dataset via URL
### `llama-stack-client datasets unregister`
Remove a dataset
```bash
llama-stack-client datasets unregister <dataset-id>
```
Required arguments:
- `DATASET_ID`: Id of the dataset
## Scoring Functions Management
Manage scoring functions.
### `llama-stack-client scoring_functions list`
Show available scoring functions on distribution endpoint
```bash
llama-stack-client scoring_functions list
```
```
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━┓
┃ identifier ┃ provider_id ┃ description ┃ type ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━┩
│ basic::bfcl │ basic │ BFCL complex scoring │ scoring_function │
│ basic::docvqa │ basic │ DocVQA Visual Question & Answer scoring function │ scoring_function │
│ basic::equality │ basic │ Returns 1.0 if the input is equal to the target, 0.0 │ scoring_function │
│ │ │ otherwise. │ │
└────────────────────────────────────────────┴──────────────┴───────────────────────────────────────────────────────────────┴──────────────────┘
```
### `llama-stack-client scoring_functions register`
Register a new scoring function
```bash
llama-stack-client scoring_functions register --scoring-fn-id <scoring-fn-id> --description <description> --return-type <return-type> [--provider-id <provider-id>] [--provider-scoring-fn-id <provider-scoring-fn-id>] [--params <params>]
```
Required arguments:
- `--scoring-fn-id`: Id of the scoring function
- `--description`: Description of the scoring function
- `--return-type`: Return type of the scoring function
Optional arguments:
- `--provider-id`: Provider ID for the scoring function
- `--provider-scoring-fn-id`: Provider's scoring function ID
- `--params`: Parameters for the scoring function in JSON format
## Post Training Management
Post-training.
### `llama-stack-client post_training list`
Show the list of available post training jobs
```bash
llama-stack-client post_training list
```
```bash
["job-1", "job-2", "job-3"]
```
### `llama-stack-client post_training artifacts`
Get the training artifacts of a specific post training job
```bash
llama-stack-client post_training artifacts --job-uuid <job-uuid>
```
```bash
JobArtifactsResponse(checkpoints=[], job_uuid='job-1')
```
Required arguments:
- `--job-uuid`: Job UUID
### `llama-stack-client post_training supervised_fine_tune`
Kick off a supervised fine tune job
```bash
llama-stack-client post_training supervised_fine_tune --job-uuid <job-uuid> --model <model> --algorithm-config <algorithm-config> --training-config <training-config> [--checkpoint-dir <checkpoint-dir>]
```
Required arguments:
- `--job-uuid`: Job UUID
- `--model`: Model ID
- `--algorithm-config`: Algorithm Config
- `--training-config`: Training Config
Optional arguments:
- `--checkpoint-dir`: Checkpoint Config
### `llama-stack-client post_training status`
Show the status of a specific post training job
```bash
llama-stack-client post_training status --job-uuid <job-uuid>
```
```bash
JobStatusResponse(
checkpoints=[],
job_uuid='job-1',
status='completed',
completed_at="",
resources_allocated="",
scheduled_at="",
started_at=""
)
```
Required arguments:
- `--job-uuid`: Job UUID
### `llama-stack-client post_training cancel`
Cancel the training job
```bash
llama-stack-client post_training cancel --job-uuid <job-uuid>
```
```bash
# This functionality is not yet implemented for llama-stack-client
╭────────────────────────────────────────────────────────────╮
│ Failed to post_training cancel_training_job │
│ │
│ Error Type: InternalServerError │
│ Details: Error code: 501 - {'detail': 'Not implemented: '} │
╰────────────────────────────────────────────────────────────╯
```
Required arguments:
- `--job-uuid`: Job UUID

View file

@ -19,6 +19,7 @@ class VectorDB(Resource):
embedding_model: str
embedding_dimension: int
vector_db_name: str | None = None
@property
def vector_db_id(self) -> str:
@ -70,6 +71,7 @@ class VectorDBs(Protocol):
embedding_model: str,
embedding_dimension: int | None = 384,
provider_id: str | None = None,
vector_db_name: str | None = None,
provider_vector_db_id: str | None = None,
) -> VectorDB:
"""Register a vector database.
@ -78,6 +80,7 @@ class VectorDBs(Protocol):
:param embedding_model: The embedding model to use.
:param embedding_dimension: The dimension of the embedding model.
:param provider_id: The identifier of the provider.
:param vector_db_name: The name of the vector database.
:param provider_vector_db_id: The identifier of the vector database in the provider.
:returns: A VectorDB.
"""

View file

@ -346,7 +346,6 @@ class VectorIO(Protocol):
embedding_model: str | None = None,
embedding_dimension: int | None = 384,
provider_id: str | None = None,
provider_vector_db_id: str | None = None,
) -> VectorStoreObject:
"""Creates a vector store.
@ -358,7 +357,6 @@ class VectorIO(Protocol):
:param embedding_model: The embedding model to use for this vector store.
:param embedding_dimension: The dimension of the embedding vectors (default: 384).
:param provider_id: The ID of the provider to use for this vector store.
:param provider_vector_db_id: The provider-specific vector database ID.
:returns: A VectorStoreObject representing the created vector store.
"""
...

View file

@ -17,7 +17,7 @@ from llama_stack.distribution.distribution import (
builtin_automatically_routed_apis,
get_provider_registry,
)
from llama_stack.distribution.stack import replace_env_vars
from llama_stack.distribution.stack import cast_image_name_to_string, replace_env_vars
from llama_stack.distribution.utils.config_dirs import EXTERNAL_PROVIDERS_DIR
from llama_stack.distribution.utils.dynamic import instantiate_class_type
from llama_stack.distribution.utils.prompt_for_config import prompt_for_config
@ -164,7 +164,8 @@ def upgrade_from_routing_table(
def parse_and_maybe_upgrade_config(config_dict: dict[str, Any]) -> StackRunConfig:
version = config_dict.get("version", None)
if version == LLAMA_STACK_RUN_CONFIG_VERSION:
return StackRunConfig(**replace_env_vars(config_dict))
processed_config_dict = replace_env_vars(config_dict)
return StackRunConfig(**cast_image_name_to_string(processed_config_dict))
if "routing_table" in config_dict:
logger.info("Upgrading config...")
@ -175,4 +176,5 @@ def parse_and_maybe_upgrade_config(config_dict: dict[str, Any]) -> StackRunConfi
if not config_dict.get("external_providers_dir", None):
config_dict["external_providers_dir"] = EXTERNAL_PROVIDERS_DIR
return StackRunConfig(**replace_env_vars(config_dict))
processed_config_dict = replace_env_vars(config_dict)
return StackRunConfig(**cast_image_name_to_string(processed_config_dict))

View file

@ -200,7 +200,7 @@ def validate_and_prepare_providers(
specs = {}
for provider in providers:
if not provider.provider_id or provider.provider_id == "__disabled__":
logger.warning(f"Provider `{provider.provider_type}` for API `{api}` is disabled")
logger.debug(f"Provider `{provider.provider_type}` for API `{api}` is disabled")
continue
validate_provider(provider, api, provider_registry)

View file

@ -5,6 +5,7 @@
# the root directory of this source tree.
import asyncio
import uuid
from typing import Any
from llama_stack.apis.common.content_types import (
@ -81,6 +82,7 @@ class VectorIORouter(VectorIO):
embedding_model: str,
embedding_dimension: int | None = 384,
provider_id: str | None = None,
vector_db_name: str | None = None,
provider_vector_db_id: str | None = None,
) -> None:
logger.debug(f"VectorIORouter.register_vector_db: {vector_db_id}, {embedding_model}")
@ -89,6 +91,7 @@ class VectorIORouter(VectorIO):
embedding_model,
embedding_dimension,
provider_id,
vector_db_name,
provider_vector_db_id,
)
@ -123,7 +126,6 @@ class VectorIORouter(VectorIO):
embedding_model: str | None = None,
embedding_dimension: int | None = None,
provider_id: str | None = None,
provider_vector_db_id: str | None = None,
) -> VectorStoreObject:
logger.debug(f"VectorIORouter.openai_create_vector_store: name={name}, provider_id={provider_id}")
@ -135,17 +137,17 @@ class VectorIORouter(VectorIO):
embedding_model, embedding_dimension = embedding_model_info
logger.info(f"No embedding model specified, using first available: {embedding_model}")
vector_db_id = name
vector_db_id = f"vs_{uuid.uuid4()}"
registered_vector_db = await self.routing_table.register_vector_db(
vector_db_id,
embedding_model,
embedding_dimension,
provider_id,
provider_vector_db_id,
vector_db_id=vector_db_id,
embedding_model=embedding_model,
embedding_dimension=embedding_dimension,
provider_id=provider_id,
provider_vector_db_id=vector_db_id,
vector_db_name=name,
)
return await self.routing_table.get_provider_impl(registered_vector_db.identifier).openai_create_vector_store(
vector_db_id,
name=name,
file_ids=file_ids,
expires_after=expires_after,
chunking_strategy=chunking_strategy,

View file

@ -36,6 +36,7 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
embedding_dimension: int | None = 384,
provider_id: str | None = None,
provider_vector_db_id: str | None = None,
vector_db_name: str | None = None,
) -> VectorDB:
if provider_vector_db_id is None:
provider_vector_db_id = vector_db_id
@ -62,6 +63,7 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
"provider_resource_id": provider_vector_db_id,
"embedding_model": embedding_model,
"embedding_dimension": model.metadata["embedding_dimension"],
"vector_db_name": vector_db_name,
}
vector_db = TypeAdapter(VectorDBWithOwner).validate_python(vector_db_data)
await self.register_object(vector_db)

View file

@ -47,6 +47,7 @@ from llama_stack.distribution.server.routes import (
initialize_route_impls,
)
from llama_stack.distribution.stack import (
cast_image_name_to_string,
construct_stack,
replace_env_vars,
validate_env_pair,
@ -439,7 +440,7 @@ def main(args: argparse.Namespace | None = None):
logger.error(f"Error: {str(e)}")
sys.exit(1)
config = replace_env_vars(config_contents)
config = StackRunConfig(**config)
config = StackRunConfig(**cast_image_name_to_string(config))
# now that the logger is initialized, print the line about which type of config we are using.
logger.info(log_line)

View file

@ -267,6 +267,13 @@ def _convert_string_to_proper_type(value: str) -> Any:
return value
def cast_image_name_to_string(config_dict: dict[str, Any]) -> dict[str, Any]:
    """Coerce the top-level ``image_name`` entry of a config dict to ``str``.

    The dict is modified in place and the same object is returned. A missing
    key or a value of ``None`` is left untouched.

    :param config_dict: The configuration mapping to normalize.
    :return: The same ``config_dict`` object, with ``image_name`` stringified when present.
    """
    image_name = config_dict.get("image_name")
    if image_name is None:
        # Covers both "key absent" and "key present but None".
        return config_dict
    config_dict["image_name"] = str(image_name)
    return config_dict
def validate_env_pair(env_pair: str) -> tuple[str, str]:
"""Validate and split an environment variable key-value pair."""
try:

View file

@ -51,6 +51,9 @@ class LocalfsFilesImpl(Files):
},
)
async def shutdown(self) -> None:
    """No-op shutdown hook; this implementation performs no work on shutdown."""
    return None
def _generate_file_id(self) -> str:
"""Generate a unique file ID for OpenAI API."""
return f"file-{uuid.uuid4().hex}"

View file

@ -7,6 +7,7 @@
import asyncio
import json
import logging
import re
import sqlite3
import struct
from typing import Any
@ -117,6 +118,10 @@ def _rrf_rerank(
return rrf_scores
def _make_sql_identifier(name: str) -> str:
return re.sub(r"[^a-zA-Z0-9_]", "_", name)
class SQLiteVecIndex(EmbeddingIndex):
"""
An index implementation that stores embeddings in a SQLite virtual table using sqlite-vec.
@ -130,9 +135,9 @@ class SQLiteVecIndex(EmbeddingIndex):
self.dimension = dimension
self.db_path = db_path
self.bank_id = bank_id
self.metadata_table = f"chunks_{bank_id}".replace("-", "_")
self.vector_table = f"vec_chunks_{bank_id}".replace("-", "_")
self.fts_table = f"fts_chunks_{bank_id}".replace("-", "_")
self.metadata_table = _make_sql_identifier(f"chunks_{bank_id}")
self.vector_table = _make_sql_identifier(f"vec_chunks_{bank_id}")
self.fts_table = _make_sql_identifier(f"fts_chunks_{bank_id}")
self.kvstore = kvstore
@classmethod
@ -148,14 +153,14 @@ class SQLiteVecIndex(EmbeddingIndex):
try:
# Create the table to store chunk metadata.
cur.execute(f"""
CREATE TABLE IF NOT EXISTS {self.metadata_table} (
CREATE TABLE IF NOT EXISTS [{self.metadata_table}] (
id TEXT PRIMARY KEY,
chunk TEXT
);
""")
# Create the virtual table for embeddings.
cur.execute(f"""
CREATE VIRTUAL TABLE IF NOT EXISTS {self.vector_table}
CREATE VIRTUAL TABLE IF NOT EXISTS [{self.vector_table}]
USING vec0(embedding FLOAT[{self.dimension}], id TEXT);
""")
connection.commit()
@ -163,7 +168,7 @@ class SQLiteVecIndex(EmbeddingIndex):
# based on query. Implementation of the change on client side will allow passing the search_mode option
# during initialization to make it easier to create the table that is required.
cur.execute(f"""
CREATE VIRTUAL TABLE IF NOT EXISTS {self.fts_table}
CREATE VIRTUAL TABLE IF NOT EXISTS [{self.fts_table}]
USING fts5(id, content);
""")
connection.commit()
@ -178,9 +183,9 @@ class SQLiteVecIndex(EmbeddingIndex):
connection = _create_sqlite_connection(self.db_path)
cur = connection.cursor()
try:
cur.execute(f"DROP TABLE IF EXISTS {self.metadata_table};")
cur.execute(f"DROP TABLE IF EXISTS {self.vector_table};")
cur.execute(f"DROP TABLE IF EXISTS {self.fts_table};")
cur.execute(f"DROP TABLE IF EXISTS [{self.metadata_table}];")
cur.execute(f"DROP TABLE IF EXISTS [{self.vector_table}];")
cur.execute(f"DROP TABLE IF EXISTS [{self.fts_table}];")
connection.commit()
finally:
cur.close()
@ -212,7 +217,7 @@ class SQLiteVecIndex(EmbeddingIndex):
metadata_data = [(chunk.chunk_id, chunk.model_dump_json()) for chunk in batch_chunks]
cur.executemany(
f"""
INSERT INTO {self.metadata_table} (id, chunk)
INSERT INTO [{self.metadata_table}] (id, chunk)
VALUES (?, ?)
ON CONFLICT(id) DO UPDATE SET chunk = excluded.chunk;
""",
@ -230,7 +235,7 @@ class SQLiteVecIndex(EmbeddingIndex):
for chunk, emb in zip(batch_chunks, batch_embeddings, strict=True)
]
cur.executemany(
f"INSERT INTO {self.vector_table} (id, embedding) VALUES (?, ?);",
f"INSERT INTO [{self.vector_table}] (id, embedding) VALUES (?, ?);",
embedding_data,
)
@ -238,13 +243,13 @@ class SQLiteVecIndex(EmbeddingIndex):
fts_data = [(chunk.chunk_id, chunk.content) for chunk in batch_chunks]
# DELETE existing entries with same IDs (FTS5 doesn't support ON CONFLICT)
cur.executemany(
f"DELETE FROM {self.fts_table} WHERE id = ?;",
f"DELETE FROM [{self.fts_table}] WHERE id = ?;",
[(row[0],) for row in fts_data],
)
# INSERT new entries
cur.executemany(
f"INSERT INTO {self.fts_table} (id, content) VALUES (?, ?);",
f"INSERT INTO [{self.fts_table}] (id, content) VALUES (?, ?);",
fts_data,
)
@ -280,8 +285,8 @@ class SQLiteVecIndex(EmbeddingIndex):
emb_blob = serialize_vector(emb_list)
query_sql = f"""
SELECT m.id, m.chunk, v.distance
FROM {self.vector_table} AS v
JOIN {self.metadata_table} AS m ON m.id = v.id
FROM [{self.vector_table}] AS v
JOIN [{self.metadata_table}] AS m ON m.id = v.id
WHERE v.embedding MATCH ? AND k = ?
ORDER BY v.distance;
"""
@ -322,9 +327,9 @@ class SQLiteVecIndex(EmbeddingIndex):
cur = connection.cursor()
try:
query_sql = f"""
SELECT DISTINCT m.id, m.chunk, bm25({self.fts_table}) AS score
FROM {self.fts_table} AS f
JOIN {self.metadata_table} AS m ON m.id = f.id
SELECT DISTINCT m.id, m.chunk, bm25([{self.fts_table}]) AS score
FROM [{self.fts_table}] AS f
JOIN [{self.metadata_table}] AS m ON m.id = f.id
WHERE f.content MATCH ?
ORDER BY score ASC
LIMIT ?;

View file

@ -3,16 +3,17 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import logging
from llama_stack.providers.remote.inference.llama_openai_compat.config import (
LlamaCompatConfig,
)
from llama_stack.providers.utils.inference.litellm_openai_mixin import (
LiteLLMOpenAIMixin,
)
from llama_api_client import AsyncLlamaAPIClient, NotFoundError
from llama_stack.providers.remote.inference.llama_openai_compat.config import LlamaCompatConfig
from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
from .models import MODEL_ENTRIES
logger = logging.getLogger(__name__)
class LlamaCompatInferenceAdapter(LiteLLMOpenAIMixin):
_config: LlamaCompatConfig
@ -27,8 +28,32 @@ class LlamaCompatInferenceAdapter(LiteLLMOpenAIMixin):
)
self.config = config
async def check_model_availability(self, model: str) -> bool:
    """
    Ask the Llama API whether a given model can be retrieved.

    :param model: The model identifier to check.
    :return: True if the model was retrieved successfully; False if the API
        raised NotFoundError or the lookup failed with any other exception
        (both cases are logged at error level).
    """
    available = False
    try:
        retrieved_model = await self._get_llama_api_client().models.retrieve(model)
        logger.info(f"Model {retrieved_model.id} is available from Llama API")
        available = True
    except NotFoundError:
        logger.error(f"Model {model} is not available from Llama API")
    except Exception as e:
        # Any other failure (network, auth, ...) is reported as "not available".
        logger.error(f"Failed to check model availability from Llama API: {e}")
    return available
# Startup hook: delegates entirely to the parent class's initialize().
async def initialize(self):
await super().initialize()
# Shutdown hook: delegates entirely to the parent class's shutdown().
async def shutdown(self):
await super().shutdown()
def _get_llama_api_client(self) -> AsyncLlamaAPIClient:
    """Construct an AsyncLlamaAPIClient from this adapter's API key and configured base URL.

    A new client object is created on every call.
    """
    api_key = self.get_api_key()
    base_url = self.config.openai_compat_api_base
    return AsyncLlamaAPIClient(api_key=api_key, base_url=base_url)

View file

@ -7,7 +7,6 @@
import logging
import warnings
from collections.abc import AsyncIterator
from functools import lru_cache
from typing import Any
from openai import APIConnectionError, AsyncOpenAI, BadRequestError
@ -98,41 +97,21 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
# If we can't retrieve the model, it's not available
return False
@lru_cache # noqa: B019
def _get_client(self, provider_model_id: str | None = None) -> AsyncOpenAI:
@property
def _client(self) -> AsyncOpenAI:
"""
For hosted models, https://integrate.api.nvidia.com/v1 is the primary base_url. However,
some models are hosted on different URLs. This function returns the appropriate client
for the given provider_model_id.
Returns an OpenAI client for the configured NVIDIA API endpoint.
This relies on lru_cache and self._default_client to avoid creating a new client for each request
or for each model that is hosted on https://integrate.api.nvidia.com/v1.
:param provider_model_id: The provider model ID (optional, defaults to primary endpoint)
:return: An OpenAI client
"""
@lru_cache # noqa: B019
def _get_client_for_base_url(base_url: str) -> AsyncOpenAI:
"""
Maintain a single OpenAI client per base_url.
"""
return AsyncOpenAI(
base_url=base_url,
api_key=(self._config.api_key.get_secret_value() if self._config.api_key else "NO KEY"),
timeout=self._config.timeout,
)
special_model_urls = {
"meta/llama-3.2-11b-vision-instruct": "https://ai.api.nvidia.com/v1/gr/meta/llama-3.2-11b-vision-instruct",
"meta/llama-3.2-90b-vision-instruct": "https://ai.api.nvidia.com/v1/gr/meta/llama-3.2-90b-vision-instruct",
}
base_url = f"{self._config.url}/v1" if self._config.append_api_version else self._config.url
if provider_model_id and _is_nvidia_hosted(self._config) and provider_model_id in special_model_urls:
base_url = special_model_urls[provider_model_id]
return _get_client_for_base_url(base_url)
return AsyncOpenAI(
base_url=base_url,
api_key=(self._config.api_key.get_secret_value() if self._config.api_key else "NO KEY"),
timeout=self._config.timeout,
)
async def _get_provider_model_id(self, model_id: str) -> str:
if not self.model_store:
@ -174,7 +153,7 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
)
try:
response = await self._get_client(provider_model_id).completions.create(**request)
response = await self._client.completions.create(**request)
except APIConnectionError as e:
raise ConnectionError(f"Failed to connect to NVIDIA NIM at {self._config.url}: {e}") from e
@ -227,7 +206,7 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
extra_body["input_type"] = task_type_options[task_type]
try:
response = await self._get_client(provider_model_id).embeddings.create(
response = await self._client.embeddings.create(
model=provider_model_id,
input=input,
extra_body=extra_body,
@ -288,7 +267,7 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
)
try:
response = await self._get_client(provider_model_id).chat.completions.create(**request)
response = await self._client.chat.completions.create(**request)
except APIConnectionError as e:
raise ConnectionError(f"Failed to connect to NVIDIA NIM at {self._config.url}: {e}") from e
@ -344,7 +323,7 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
)
try:
return await self._get_client(provider_model_id).completions.create(**params)
return await self._client.completions.create(**params)
except APIConnectionError as e:
raise ConnectionError(f"Failed to connect to NVIDIA NIM at {self._config.url}: {e}") from e
@ -403,6 +382,6 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
)
try:
return await self._get_client(provider_model_id).chat.completions.create(**params)
return await self._client.chat.completions.create(**params)
except APIConnectionError as e:
raise ConnectionError(f"Failed to connect to NVIDIA NIM at {self._config.url}: {e}") from e

View file

@ -8,7 +8,7 @@ import logging
from collections.abc import AsyncIterator
from typing import Any
from openai import AsyncOpenAI
from openai import AsyncOpenAI, NotFoundError
from llama_stack.apis.inference import (
OpenAIChatCompletion,
@ -60,6 +60,27 @@ class OpenAIInferenceAdapter(LiteLLMOpenAIMixin):
# litellm specific model names, an abstraction leak.
self.is_openai_compat = True
async def check_model_availability(self, model: str) -> bool:
    """
    Ask OpenAI whether a given model can be retrieved.

    :param model: The model identifier to check.
    :return: True if the model was retrieved successfully; False if the API
        raised NotFoundError or the lookup failed with any other exception
        (both cases are logged at error level).
    """
    available = False
    try:
        retrieved_model = await self._get_openai_client().models.retrieve(model)
        logger.info(f"Model {retrieved_model.id} is available from OpenAI")
        available = True
    except NotFoundError:
        logger.error(f"Model {model} is not available from OpenAI")
    except Exception as e:
        # Any other failure (network, auth, ...) is reported as "not available".
        logger.error(f"Failed to check model availability from OpenAI: {e}")
    return available
async def initialize(self) -> None:
await super().initialize()

View file

@ -217,7 +217,6 @@ class ChromaVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
embedding_model: str | None = None,
embedding_dimension: int | None = 384,
provider_id: str | None = None,
provider_vector_db_id: str | None = None,
) -> VectorStoreObject:
raise NotImplementedError("OpenAI Vector Stores API is not supported in Chroma")

View file

@ -8,7 +8,7 @@ from typing import Any
from pydantic import BaseModel, ConfigDict, Field
from llama_stack.providers.utils.kvstore.config import KVStoreConfig
from llama_stack.providers.utils.kvstore.config import KVStoreConfig, SqliteKVStoreConfig
from llama_stack.schema_utils import json_schema_type
@ -17,7 +17,7 @@ class MilvusVectorIOConfig(BaseModel):
uri: str = Field(description="The URI of the Milvus server")
token: str | None = Field(description="The token of the Milvus server")
consistency_level: str = Field(description="The consistency level of the Milvus server", default="Strong")
kvstore: KVStoreConfig | None = Field(description="Config for KV store backend (SQLite only for now)", default=None)
kvstore: KVStoreConfig = Field(description="Config for KV store backend")
# This configuration allows additional fields to be passed through to the underlying Milvus client.
# See the [Milvus](https://milvus.io/docs/install-overview.md) documentation for more details about Milvus in general.
@ -25,4 +25,11 @@ class MilvusVectorIOConfig(BaseModel):
@classmethod
def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> dict[str, Any]:
return {"uri": "${env.MILVUS_ENDPOINT}", "token": "${env.MILVUS_TOKEN}"}
return {
"uri": "${env.MILVUS_ENDPOINT}",
"token": "${env.MILVUS_TOKEN}",
"kvstore": SqliteKVStoreConfig.sample_run_config(
__distro_dir__=__distro_dir__,
db_name="milvus_remote_registry.db",
),
}

View file

@ -12,7 +12,7 @@ import re
from typing import Any
from numpy.typing import NDArray
from pymilvus import DataType, MilvusClient
from pymilvus import DataType, Function, FunctionType, MilvusClient
from llama_stack.apis.files.files import Files
from llama_stack.apis.inference import Inference, InterleavedContent
@ -74,12 +74,66 @@ class MilvusIndex(EmbeddingIndex):
assert len(chunks) == len(embeddings), (
f"Chunk length {len(chunks)} does not match embedding length {len(embeddings)}"
)
if not await asyncio.to_thread(self.client.has_collection, self.collection_name):
logger.info(f"Creating new collection {self.collection_name} with nullable sparse field")
# Create schema for vector search
schema = self.client.create_schema()
schema.add_field(
field_name="chunk_id",
datatype=DataType.VARCHAR,
is_primary=True,
max_length=100,
)
schema.add_field(
field_name="content",
datatype=DataType.VARCHAR,
max_length=65535,
enable_analyzer=True, # Enable text analysis for BM25
)
schema.add_field(
field_name="vector",
datatype=DataType.FLOAT_VECTOR,
dim=len(embeddings[0]),
)
schema.add_field(
field_name="chunk_content",
datatype=DataType.JSON,
)
# Add sparse vector field for BM25 (required by the function)
schema.add_field(
field_name="sparse",
datatype=DataType.SPARSE_FLOAT_VECTOR,
)
# Create indexes
index_params = self.client.prepare_index_params()
index_params.add_index(
field_name="vector",
index_type="FLAT",
metric_type="COSINE",
)
# Add index for sparse field (required by BM25 function)
index_params.add_index(
field_name="sparse",
index_type="SPARSE_INVERTED_INDEX",
metric_type="BM25",
)
# Add BM25 function for full-text search
bm25_function = Function(
name="text_bm25_emb",
input_field_names=["content"],
output_field_names=["sparse"],
function_type=FunctionType.BM25,
)
schema.add_function(bm25_function)
await asyncio.to_thread(
self.client.create_collection,
self.collection_name,
dimension=len(embeddings[0]),
auto_id=True,
schema=schema,
index_params=index_params,
consistency_level=self.consistency_level,
)
@ -88,8 +142,10 @@ class MilvusIndex(EmbeddingIndex):
data.append(
{
"chunk_id": chunk.chunk_id,
"content": chunk.content,
"vector": embedding,
"chunk_content": chunk.model_dump(),
# sparse field will be handled by BM25 function automatically
}
)
try:
@ -107,6 +163,7 @@ class MilvusIndex(EmbeddingIndex):
self.client.search,
collection_name=self.collection_name,
data=[embedding],
anns_field="vector",
limit=k,
output_fields=["*"],
search_params={"params": {"radius": score_threshold}},
@ -121,7 +178,64 @@ class MilvusIndex(EmbeddingIndex):
k: int,
score_threshold: float,
) -> QueryChunksResponse:
raise NotImplementedError("Keyword search is not supported in Milvus")
"""
Perform BM25-based keyword search using Milvus's built-in full-text search.
"""
try:
# Use Milvus's built-in BM25 search
search_res = await asyncio.to_thread(
self.client.search,
collection_name=self.collection_name,
data=[query_string], # Raw text query
anns_field="sparse", # Use sparse field for BM25
output_fields=["chunk_content"], # Output the chunk content
limit=k,
search_params={
"params": {
"drop_ratio_search": 0.2, # Ignore low-importance terms
}
},
)
chunks = []
scores = []
for res in search_res[0]:
chunk = Chunk(**res["entity"]["chunk_content"])
chunks.append(chunk)
scores.append(res["distance"]) # BM25 score from Milvus
# Filter by score threshold
filtered_chunks = [chunk for chunk, score in zip(chunks, scores, strict=False) if score >= score_threshold]
filtered_scores = [score for score in scores if score >= score_threshold]
return QueryChunksResponse(chunks=filtered_chunks, scores=filtered_scores)
except Exception as e:
logger.error(f"Error performing BM25 search: {e}")
# Fallback to simple text search
return await self._fallback_keyword_search(query_string, k, score_threshold)
async def _fallback_keyword_search(
    self,
    query_string: str,
    k: int,
    score_threshold: float,
) -> QueryChunksResponse:
    """
    Substring-match fallback used when the BM25 search path raises.

    Issues a ``client.query`` with a ``content like`` filter and returns at
    most ``k`` matches. Every match is given the same score of 1.0.

    :param query_string: text passed to the filter as the ``content`` parameter
    :param k: maximum number of rows requested from Milvus
    :param score_threshold: accepted for signature parity; not applied here
    :returns: QueryChunksResponse with the matched chunks and a 1.0 score each
    """
    # NOTE(review): the ``{content}`` placeholder sits inside a quoted string
    # literal in the filter expression -- confirm that Milvus filter templating
    # substitutes it there rather than matching the literal text.
    rows = await asyncio.to_thread(
        self.client.query,
        collection_name=self.collection_name,
        filter='content like "%{content}%"',
        filter_params={"content": query_string},
        output_fields=["*"],
        limit=k,
    )

    matched = []
    for row in rows:
        matched.append(Chunk(**row["chunk_content"]))

    # A row either matched the filter or it did not, so the score is binary.
    return QueryChunksResponse(chunks=matched, scores=[1.0] * len(matched))
async def query_hybrid(
self,
@ -247,6 +361,14 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
if not index:
raise ValueError(f"Vector DB {vector_db_id} not found")
if params and params.get("mode") == "keyword":
# Check if this is inline Milvus (Milvus-Lite)
if hasattr(self.config, "db_path"):
raise NotImplementedError(
"Keyword search is not supported in Milvus-Lite. "
"Please use a remote Milvus server for keyword search functionality."
)
return await index.query_chunks(query, params)
async def _save_openai_vector_store_file(

View file

@ -218,9 +218,6 @@ class PGVectorVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtoco
async def register_vector_db(self, vector_db: VectorDB) -> None:
# Persist vector DB metadata in the KV store
assert self.kvstore is not None
key = f"{VECTOR_DBS_PREFIX}{vector_db.identifier}"
await self.kvstore.set(key=key, value=vector_db.model_dump_json())
# Upsert model metadata in Postgres
upsert_models(self.conn, [(vector_db.identifier, vector_db)])
@ -273,16 +270,120 @@ class PGVectorVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtoco
async def _save_openai_vector_store_file(
self, store_id: str, file_id: str, file_info: dict[str, Any], file_contents: list[dict[str, Any]]
) -> None:
raise NotImplementedError("OpenAI Vector Stores API is not supported in PGVector")
"""Save vector store file metadata to Postgres database."""
if self.conn is None:
raise RuntimeError("PostgreSQL connection is not initialized")
try:
with self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
cur.execute(
"""
CREATE TABLE IF NOT EXISTS openai_vector_store_files (
store_id TEXT,
file_id TEXT,
metadata JSONB,
PRIMARY KEY (store_id, file_id)
)
"""
)
cur.execute(
"""
CREATE TABLE IF NOT EXISTS openai_vector_store_files_contents (
store_id TEXT,
file_id TEXT,
contents JSONB,
PRIMARY KEY (store_id, file_id)
)
"""
)
# Insert file metadata
files_query = sql.SQL(
"""
INSERT INTO openai_vector_store_files (store_id, file_id, metadata)
VALUES %s
ON CONFLICT (store_id, file_id) DO UPDATE SET metadata = EXCLUDED.metadata
"""
)
files_values = [(store_id, file_id, Json(file_info))]
execute_values(cur, files_query, files_values, template="(%s, %s, %s)")
# Insert file contents
contents_query = sql.SQL(
"""
INSERT INTO openai_vector_store_files_contents (store_id, file_id, contents)
VALUES %s
ON CONFLICT (store_id, file_id) DO UPDATE SET contents = EXCLUDED.contents
"""
)
contents_values = [(store_id, file_id, Json(file_contents))]
execute_values(cur, contents_query, contents_values, template="(%s, %s, %s)")
except Exception as e:
log.error(f"Error saving openai vector store file {file_id} for store {store_id}: {e}")
raise
async def _load_openai_vector_store_file(self, store_id: str, file_id: str) -> dict[str, Any]:
raise NotImplementedError("OpenAI Vector Stores API is not supported in PGVector")
"""Load vector store file metadata from Postgres database."""
if self.conn is None:
raise RuntimeError("PostgreSQL connection is not initialized")
try:
with self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
cur.execute(
"SELECT metadata FROM openai_vector_store_files WHERE store_id = %s AND file_id = %s",
(store_id, file_id),
)
row = cur.fetchone()
return row[0] if row and row[0] is not None else {}
except Exception as e:
log.error(f"Error loading openai vector store file {file_id} for store {store_id}: {e}")
return {}
async def _load_openai_vector_store_file_contents(self, store_id: str, file_id: str) -> list[dict[str, Any]]:
raise NotImplementedError("OpenAI Vector Stores API is not supported in PGVector")
"""Load vector store file contents from Postgres database."""
if self.conn is None:
raise RuntimeError("PostgreSQL connection is not initialized")
try:
with self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
cur.execute(
"SELECT contents FROM openai_vector_store_files_contents WHERE store_id = %s AND file_id = %s",
(store_id, file_id),
)
row = cur.fetchone()
return row[0] if row and row[0] is not None else []
except Exception as e:
log.error(f"Error loading openai vector store file contents for {file_id} in store {store_id}: {e}")
return []
async def _update_openai_vector_store_file(self, store_id: str, file_id: str, file_info: dict[str, Any]) -> None:
raise NotImplementedError("OpenAI Vector Stores API is not supported in PGVector")
"""Update vector store file metadata in Postgres database."""
if self.conn is None:
raise RuntimeError("PostgreSQL connection is not initialized")
try:
with self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
query = sql.SQL(
"""
INSERT INTO openai_vector_store_files (store_id, file_id, metadata)
VALUES %s
ON CONFLICT (store_id, file_id) DO UPDATE SET metadata = EXCLUDED.metadata
"""
)
values = [(store_id, file_id, Json(file_info))]
execute_values(cur, query, values, template="(%s, %s, %s)")
except Exception as e:
log.error(f"Error updating openai vector store file {file_id} for store {store_id}: {e}")
raise
async def _delete_openai_vector_store_file_from_storage(self, store_id: str, file_id: str) -> None:
raise NotImplementedError("OpenAI Vector Stores API is not supported in PGVector")
"""Delete vector store file metadata from Postgres database."""
if self.conn is None:
raise RuntimeError("PostgreSQL connection is not initialized")
try:
with self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
cur.execute(
"DELETE FROM openai_vector_store_files WHERE store_id = %s AND file_id = %s",
(store_id, file_id),
)
cur.execute(
"DELETE FROM openai_vector_store_files_contents WHERE store_id = %s AND file_id = %s",
(store_id, file_id),
)
except Exception as e:
log.error(f"Error deleting openai vector store file {file_id} for store {store_id}: {e}")
raise

View file

@ -214,7 +214,6 @@ class QdrantVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
embedding_model: str | None = None,
embedding_dimension: int | None = 384,
provider_id: str | None = None,
provider_vector_db_id: str | None = None,
) -> VectorStoreObject:
raise NotImplementedError("OpenAI Vector Stores API is not supported in Qdrant")

View file

@ -13,7 +13,6 @@ from llama_stack.apis.common.content_types import (
InterleavedContent,
InterleavedContentItem,
)
from llama_stack.apis.common.errors import UnsupportedModelError
from llama_stack.apis.inference import (
ChatCompletionRequest,
ChatCompletionResponse,
@ -39,7 +38,6 @@ from llama_stack.apis.inference import (
ToolDefinition,
ToolPromptFormat,
)
from llama_stack.apis.models import Model
from llama_stack.distribution.request_headers import NeedsRequestProviderData
from llama_stack.log import get_logger
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
@ -90,12 +88,6 @@ class LiteLLMOpenAIMixin(
async def shutdown(self):
pass
async def register_model(self, model: Model) -> Model:
model_id = self.get_provider_model_id(model.provider_resource_id)
if model_id is None:
raise UnsupportedModelError(model.provider_resource_id, self.alias_to_provider_id_map.keys())
return model
def get_litellm_model_name(self, model_id: str) -> str:
# users may be using openai/ prefix in their model names. the openai/models.py did this by default.
# model_id.startswith("openai/") is for backwards compatibility.

View file

@ -172,8 +172,9 @@ class OpenAIVectorStoreMixin(ABC):
provider_vector_db_id: str | None = None,
) -> VectorStoreObject:
"""Creates a vector store."""
store_id = name or str(uuid.uuid4())
created_at = int(time.time())
# Derive the canonical vector_db_id (allow override, else generate)
vector_db_id = provider_vector_db_id or f"vs_{uuid.uuid4()}"
if provider_id is None:
raise ValueError("Provider ID is required")
@ -181,19 +182,19 @@ class OpenAIVectorStoreMixin(ABC):
if embedding_model is None:
raise ValueError("Embedding model is required")
# Use provided embedding dimension or default to 384
# Embedding dimension is required (defaulted to 384 if not provided)
if embedding_dimension is None:
raise ValueError("Embedding dimension is required")
provider_vector_db_id = provider_vector_db_id or store_id
# Register the VectorDB backing this vector store
vector_db = VectorDB(
identifier=store_id,
identifier=vector_db_id,
embedding_dimension=embedding_dimension,
embedding_model=embedding_model,
provider_id=provider_id,
provider_resource_id=provider_vector_db_id,
provider_resource_id=vector_db_id,
vector_db_name=name,
)
# Register the vector DB
await self.register_vector_db(vector_db)
# Create OpenAI vector store metadata
@ -207,11 +208,11 @@ class OpenAIVectorStoreMixin(ABC):
in_progress=0,
total=0,
)
store_info = {
"id": store_id,
store_info: dict[str, Any] = {
"id": vector_db_id,
"object": "vector_store",
"created_at": created_at,
"name": store_id,
"name": name,
"usage_bytes": 0,
"file_counts": file_counts.model_dump(),
"status": status,
@ -231,18 +232,18 @@ class OpenAIVectorStoreMixin(ABC):
store_info["metadata"] = metadata
# Save to persistent storage (provider-specific)
await self._save_openai_vector_store(store_id, store_info)
await self._save_openai_vector_store(vector_db_id, store_info)
# Store in memory cache
self.openai_vector_stores[store_id] = store_info
self.openai_vector_stores[vector_db_id] = store_info
# Now that our vector store is created, attach any files that were provided
file_ids = file_ids or []
tasks = [self.openai_attach_file_to_vector_store(store_id, file_id) for file_id in file_ids]
tasks = [self.openai_attach_file_to_vector_store(vector_db_id, file_id) for file_id in file_ids]
await asyncio.gather(*tasks)
# Get the updated store info and return it
store_info = self.openai_vector_stores[store_id]
store_info = self.openai_vector_stores[vector_db_id]
return VectorStoreObject.model_validate(store_info)
async def openai_list_vector_stores(

View file

@ -20,7 +20,7 @@
"@radix-ui/react-tooltip": "^1.2.6",
"class-variance-authority": "^0.7.1",
"clsx": "^2.1.1",
"llama-stack-client": "^0.2.14",
"llama-stack-client": "^0.2.15",
"lucide-react": "^0.510.0",
"next": "15.3.3",
"next-auth": "^4.24.11",

View file

@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
[project]
name = "llama_stack"
version = "0.2.14"
version = "0.2.15"
authors = [{ name = "Meta Llama", email = "llama-oss@meta.com" }]
description = "Llama Stack"
readme = "README.md"
@ -28,7 +28,8 @@ dependencies = [
"huggingface-hub>=0.30.0,<1.0",
"jinja2>=3.1.6",
"jsonschema",
"llama-stack-client>=0.2.14",
"llama-stack-client>=0.2.15",
"llama-api-client>=0.1.2",
"openai>=1.66",
"prompt-toolkit",
"python-dotenv",
@ -52,7 +53,7 @@ dependencies = [
ui = [
"streamlit",
"pandas",
"llama-stack-client>=0.2.14",
"llama-stack-client>=0.2.15",
"streamlit-option-menu",
]
@ -125,6 +126,7 @@ docs = [
"sphinxcontrib.redoc",
"sphinxcontrib.video",
"sphinxcontrib.mermaid",
"sphinx-reredirects",
"tomli",
"linkify",
"sphinxcontrib.openapi",

View file

@ -13,6 +13,7 @@ annotated-types==0.7.0
anyio==4.8.0
# via
# httpx
# llama-api-client
# llama-stack-client
# openai
# starlette
@ -49,6 +50,7 @@ deprecated==1.2.18
# opentelemetry-semantic-conventions
distro==1.9.0
# via
# llama-api-client
# llama-stack-client
# openai
ecdsa==0.19.1
@ -80,6 +82,7 @@ httpcore==1.0.9
# via httpx
httpx==0.28.1
# via
# llama-api-client
# llama-stack
# llama-stack-client
# openai
@ -101,7 +104,9 @@ jsonschema==4.23.0
# via llama-stack
jsonschema-specifications==2024.10.1
# via jsonschema
llama-stack-client==0.2.14
llama-api-client==0.1.2
# via llama-stack
llama-stack-client==0.2.15
# via llama-stack
markdown-it-py==3.0.0
# via rich
@ -165,6 +170,7 @@ pycparser==2.22 ; platform_python_implementation != 'PyPy'
pydantic==2.10.6
# via
# fastapi
# llama-api-client
# llama-stack
# llama-stack-client
# openai
@ -215,6 +221,7 @@ six==1.17.0
sniffio==1.3.1
# via
# anyio
# llama-api-client
# llama-stack-client
# openai
starlette==0.45.3
@ -239,6 +246,7 @@ typing-extensions==4.12.2
# anyio
# fastapi
# huggingface-hub
# llama-api-client
# llama-stack-client
# openai
# opentelemetry-sdk

View file

@ -31,7 +31,7 @@ def skip_if_provider_doesnt_support_openai_vector_stores(client_with_models):
def skip_if_provider_doesnt_support_openai_vector_store_files_api(client_with_models):
vector_io_providers = [p for p in client_with_models.providers.list() if p.api == "vector_io"]
for p in vector_io_providers:
if p.provider_type in ["inline::faiss", "inline::sqlite-vec", "inline::milvus"]:
if p.provider_type in ["inline::faiss", "inline::sqlite-vec", "inline::milvus", "remote::pgvector"]:
return
pytest.skip("OpenAI vector stores are not supported by any provider")
@ -821,6 +821,59 @@ def test_openai_vector_store_update_file(compat_client_with_empty_stores, client
assert retrieved_file.attributes["foo"] == "baz"
def test_create_vector_store_files_duplicate_vector_store_name(compat_client_with_empty_stores, client_with_models):
    """
    This test confirms that client.vector_stores.create() creates a unique ID
    even when two vector stores are created with the same name, and that
    deleting one of them leaves the other fully usable.
    """
    skip_if_provider_doesnt_support_openai_vector_stores(client_with_models)
    skip_if_provider_doesnt_support_openai_vector_store_files_api(client_with_models)

    if isinstance(compat_client_with_empty_stores, LlamaStackClient):
        pytest.skip("Vector Store Files create is not yet supported with LlamaStackClient")

    compat_client = compat_client_with_empty_stores

    # Upload the files that are attached to the first vector store further down
    file_ids = []
    for i in range(3):
        with BytesIO(f"This is a test file {i}".encode()) as file_buffer:
            file_buffer.name = f"openai_test_{i}.txt"
            file = compat_client.files.create(file=file_buffer, purpose="assistants")

        file_ids.append(file.id)

    # First store: created empty, so every file counter starts at zero
    vector_store = compat_client.vector_stores.create(
        name="test_store_with_files",
    )
    assert vector_store.file_counts.completed == 0
    assert vector_store.file_counts.total == 0
    assert vector_store.file_counts.cancelled == 0
    assert vector_store.file_counts.failed == 0
    assert vector_store.file_counts.in_progress == 0

    # Second store deliberately reuses the same name
    vector_store2 = compat_client.vector_stores.create(
        name="test_store_with_files",
    )

    # The property under test: same name, distinct IDs
    assert vector_store.id != vector_store2.id

    vector_stores_list = compat_client.vector_stores.list()
    assert len(vector_stores_list.data) == 2

    created_file = compat_client.vector_stores.files.create(
        vector_store_id=vector_store.id,
        file_id=file_ids[0],
    )
    assert created_file.status == "completed"

    # Deleting the second store must not affect the first one
    _ = compat_client.vector_stores.delete(vector_store2.id)

    created_file_from_non_deleted_vector_store = compat_client.vector_stores.files.create(
        vector_store_id=vector_store.id,
        file_id=file_ids[1],
    )
    assert created_file_from_non_deleted_vector_store.status == "completed"

    vector_stores_list_post_delete = compat_client.vector_stores.list()
    assert len(vector_stores_list_post_delete.data) == 1
@pytest.mark.skip(reason="Client library needs to be scaffolded to support search_mode parameter")
def test_openai_vector_store_search_modes():
"""Test OpenAI vector store search with different search modes.

View file

@ -15,6 +15,37 @@ from llama_stack.distribution.configure import (
)
@pytest.fixture
def config_with_image_name_int():
return yaml.safe_load(
f"""
version: {LLAMA_STACK_RUN_CONFIG_VERSION}
image_name: 1234
apis_to_serve: []
built_at: {datetime.now().isoformat()}
providers:
inference:
- provider_id: provider1
provider_type: inline::meta-reference
config: {{}}
safety:
- provider_id: provider1
provider_type: inline::meta-reference
config:
llama_guard_shield:
model: Llama-Guard-3-1B
excluded_categories: []
disable_input_check: false
disable_output_check: false
enable_prompt_guard: false
memory:
- provider_id: provider1
provider_type: inline::meta-reference
config: {{}}
"""
)
@pytest.fixture
def up_to_date_config():
return yaml.safe_load(
@ -125,3 +156,8 @@ def test_parse_and_maybe_upgrade_config_old_format(old_config):
def test_parse_and_maybe_upgrade_config_invalid(invalid_config):
with pytest.raises(KeyError):
parse_and_maybe_upgrade_config(invalid_config)
def test_parse_and_maybe_upgrade_config_image_name_int(config_with_image_name_int):
    """An integer ``image_name`` in the raw config must come back as a ``str`` on the parsed result."""
    parsed = parse_and_maybe_upgrade_config(config_with_image_name_int)
    assert isinstance(parsed.image_name, str)

View file

@ -54,7 +54,8 @@ class TestNvidiaPostTraining(unittest.TestCase):
self.mock_client.chat.completions.create = unittest.mock.AsyncMock()
self.inference_mock_make_request = self.mock_client.chat.completions.create
self.inference_make_request_patcher = patch(
"llama_stack.providers.remote.inference.nvidia.nvidia.NVIDIAInferenceAdapter._get_client",
"llama_stack.providers.remote.inference.nvidia.nvidia.NVIDIAInferenceAdapter._client",
new_callable=unittest.mock.PropertyMock,
return_value=self.mock_client,
)
self.inference_make_request_patcher.start()

View file

@ -0,0 +1,191 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from unittest.mock import MagicMock, patch
import numpy as np
import pytest
import pytest_asyncio
from llama_stack.apis.vector_io import QueryChunksResponse
# Mock the entire pymilvus module
pymilvus_mock = MagicMock()
pymilvus_mock.DataType = MagicMock()
pymilvus_mock.MilvusClient = MagicMock
# Apply the mock before importing MilvusIndex
with patch.dict("sys.modules", {"pymilvus": pymilvus_mock}):
from llama_stack.providers.remote.vector_io.milvus.milvus import MilvusIndex
# These are unit tests for the MilvusIndex class, run against a mocked Milvus client. This file
# should only contain tests which are specific to this class. More general (API-level) tests
# should be placed in tests/integration/vector_io/
#
# How to run this test:
#
# pytest tests/unit/providers/vector_io/test_milvus.py \
# -v -s --tb=short --disable-warnings --asyncio-mode=auto
MILVUS_PROVIDER = "milvus"
@pytest_asyncio.fixture
async def mock_milvus_client() -> MagicMock:
    """Build a MagicMock standing in for a Milvus client, with canned results.

    Collection calls report "no collection yet", ``insert`` reports ten rows,
    ``search`` returns two hits and ``query`` returns three rows.
    """

    def _payload(text, document_id):
        # chunk_content is a plain dict here, not a JSON string
        return {"content": text, "metadata": {"document_id": document_id}}

    fake = MagicMock()

    # Collection operations: nothing exists until a test says otherwise
    fake.has_collection.return_value = False
    fake.create_collection.return_value = None
    fake.drop_collection.return_value = None

    # Insert operation
    fake.insert.return_value = {"insert_count": 10}

    # Search operation: one result list (for the single query) holding two hits
    first_hit = {"id": 0, "distance": 0.1, "entity": {"chunk_content": _payload("mock chunk 1", "doc1")}}
    second_hit = {"id": 1, "distance": 0.2, "entity": {"chunk_content": _payload("mock chunk 2", "doc2")}}
    fake.search.return_value = [[first_hit, second_hit]]

    # Query operation used by keyword search
    fake.query.return_value = [
        {"chunk_id": "chunk1", "chunk_content": _payload("mock chunk 1", "doc1"), "score": 0.9},
        {"chunk_id": "chunk2", "chunk_content": _payload("mock chunk 2", "doc2"), "score": 0.8},
        {"chunk_id": "chunk3", "chunk_content": _payload("mock chunk 3", "doc3"), "score": 0.7},
    ]

    return fake
@pytest_asyncio.fixture
async def milvus_index(mock_milvus_client):
    """Yield a MilvusIndex wired to the mocked client; mocks need no teardown."""
    yield MilvusIndex(client=mock_milvus_client, collection_name="test_collection")
@pytest.mark.asyncio
async def test_add_chunks(milvus_index, sample_chunks, sample_embeddings, mock_milvus_client):
    """add_chunks creates the missing collection and inserts one row per chunk."""
    # First has_collection() answer says "absent"; a later check would see it present
    mock_milvus_client.has_collection.side_effect = [False, True]

    await milvus_index.add_chunks(sample_chunks, sample_embeddings)

    # The collection was created and the data inserted, each exactly once
    mock_milvus_client.create_collection.assert_called_once()
    mock_milvus_client.insert.assert_called_once()

    # The insert payload carries as many rows as there were chunks
    inserted_rows = mock_milvus_client.insert.call_args.kwargs["data"]
    assert len(inserted_rows) == len(sample_chunks)
@pytest.mark.asyncio
async def test_query_chunks_vector(
    milvus_index, sample_chunks, sample_embeddings, embedding_dimension, mock_milvus_client
):
    """query_vector returns a QueryChunksResponse with two chunks and calls search once."""
    # Collection already exists, so add_chunks goes straight to inserting
    mock_milvus_client.has_collection.return_value = True
    await milvus_index.add_chunks(sample_chunks, sample_embeddings)

    # Random probe vector of the right dimensionality; the mock ignores its values
    probe = np.random.rand(embedding_dimension).astype(np.float32)
    result = await milvus_index.query_vector(probe, k=2, score_threshold=0.0)

    assert isinstance(result, QueryChunksResponse)
    assert len(result.chunks) == 2
    mock_milvus_client.search.assert_called_once()
@pytest.mark.asyncio
async def test_query_chunks_keyword_search(milvus_index, sample_chunks, sample_embeddings, mock_milvus_client):
    """query_keyword answers from the BM25 search path, not the text-match fallback."""
    mock_milvus_client.has_collection.return_value = True
    await milvus_index.add_chunks(sample_chunks, sample_embeddings)

    # Test keyword search
    query_string = "Sentence 5"
    response = await milvus_index.query_keyword(query_string=query_string, k=2, score_threshold=0.0)

    assert isinstance(response, QueryChunksResponse)
    assert len(response.chunks) == 2

    # Pin the path taken: one search against the sparse (BM25) field with the
    # raw query text, and no call to the query()-based fallback.
    mock_milvus_client.search.assert_called_once()
    search_kwargs = mock_milvus_client.search.call_args.kwargs
    assert search_kwargs["anns_field"] == "sparse"
    assert search_kwargs["data"] == [query_string]
    mock_milvus_client.query.assert_not_called()
@pytest.mark.asyncio
async def test_bm25_fallback_to_simple_search(milvus_index, sample_chunks, sample_embeddings, mock_milvus_client):
    """When the BM25 search raises, query_keyword falls back to the query()-based text match."""
    mock_milvus_client.has_collection.return_value = True
    await milvus_index.add_chunks(sample_chunks, sample_embeddings)

    # Make the BM25 path blow up so the fallback has to take over
    mock_milvus_client.search.side_effect = Exception("BM25 search not available")

    # Rows the fallback's query() call will hand back
    mock_milvus_client.query.return_value = [
        {
            "chunk_id": "chunk1",
            "chunk_content": {"content": "Python programming language", "metadata": {"document_id": "doc1"}},
        },
        {
            "chunk_id": "chunk2",
            "chunk_content": {"content": "Machine learning algorithms", "metadata": {"document_id": "doc2"}},
        },
    ]

    response = await milvus_index.query_keyword(query_string="Python", k=3, score_threshold=0.0)

    # Shape of the response
    assert isinstance(response, QueryChunksResponse)
    assert len(response.chunks) > 0, "Fallback search should return results"

    # query() served the request; search() was attempted once and failed
    mock_milvus_client.query.assert_called_once()
    mock_milvus_client.search.assert_called_once()

    # The fallback must pass the term through filter_params rather than inline it
    fallback_kwargs = mock_milvus_client.query.call_args.kwargs
    assert "filter" in fallback_kwargs, "Query should include filter for text search"
    assert "filter_params" in fallback_kwargs, "Query should use parameterized filter"
    assert fallback_kwargs["filter_params"]["content"] == "Python", "Filter params should contain the search term"

    # Every fallback hit is scored exactly 1.0
    for score in response.scores:
        assert score == 1.0, "Simple text search should use binary scoring"
@pytest.mark.asyncio
async def test_delete_collection(milvus_index, mock_milvus_client):
    """delete() drops the index's own collection exactly once."""
    # The collection is reported as existing before the delete
    mock_milvus_client.has_collection.return_value = True

    await milvus_index.delete()

    expected_name = milvus_index.collection_name
    mock_milvus_client.drop_collection.assert_called_once_with(collection_name=expected_name)

View file

@ -37,7 +37,7 @@ def loop():
async def sqlite_vec_index(embedding_dimension, tmp_path_factory):
temp_dir = tmp_path_factory.getbasetemp()
db_path = str(temp_dir / "test_sqlite.db")
index = await SQLiteVecIndex.create(dimension=embedding_dimension, db_path=db_path, bank_id="test_bank")
index = await SQLiteVecIndex.create(dimension=embedding_dimension, db_path=db_path, bank_id="test_bank.123")
yield index
await index.delete()
@ -110,7 +110,7 @@ async def test_chunk_id_conflict(sqlite_vec_index, sample_chunks, embedding_dime
cur = connection.cursor()
# Retrieve all chunk IDs to check for duplicates
cur.execute(f"SELECT id FROM {sqlite_vec_index.metadata_table}")
cur.execute(f"SELECT id FROM [{sqlite_vec_index.metadata_table}]")
chunk_ids = [row[0] for row in cur.fetchall()]
cur.close()
connection.close()

3075
uv.lock generated

File diff suppressed because it is too large Load diff