Merge branch 'refs/heads/main' into preprocessors

# Conflicts:
#	llama_stack/distribution/routers/routers.py
#	llama_stack/templates/ollama/build.yaml
#	llama_stack/templates/ollama/run-with-safety.yaml
#	llama_stack/templates/ollama/run.yaml
#	llama_stack/templates/remote-vllm/build.yaml
#	llama_stack/templates/remote-vllm/run-with-safety.yaml
#	llama_stack/templates/remote-vllm/run.yaml
#	llama_stack/templates/together/build.yaml
#	llama_stack/templates/together/run-with-safety.yaml
#	llama_stack/templates/together/run.yaml
commit 6b9f673fdb
Author: ilya-kolchinsky
Date: 2025-03-07 16:20:30 +01:00
313 changed files with 181388 additions and 7064 deletions

.github/CODEOWNERS

@@ -2,4 +2,4 @@
 # These owners will be the default owners for everything in
 # the repo. Unless a later match takes precedence,
-* @ashwinb @yanxi0830 @hardikjshah @dltn @raghotham @dineshyv @vladimirivic @sixianyi0721 @ehhuang @terrytangyuan
+* @ashwinb @yanxi0830 @hardikjshah @dltn @raghotham @dineshyv @vladimirivic @sixianyi0721 @ehhuang @terrytangyuan @SLR722

.github/workflows/unit-tests.yml (new file)

@@ -0,0 +1,36 @@
name: Unit Tests

on:
  pull_request:
    branches: [ main ]
  workflow_dispatch:

jobs:
  unit-tests:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4

      - name: Set up Python
        uses: actions/setup-python@v5
        with:
          python-version: '3.10.16'

      - uses: astral-sh/setup-uv@v5
        with:
          python-version: '3.10.16'
          enable-cache: false

      - name: Run unit tests
        run: |
          uv run -p 3.10.16 --with . --with ".[dev]" --with ".[test]" pytest -s -v tests/unit/ --junitxml=pytest-report.xml

      - name: Upload test results
        if: always()
        uses: actions/upload-artifact@v4
        with:
          name: test-results
          path: |
            .pytest_cache/
            pytest-report.xml
          retention-days: 7


@@ -8,6 +8,8 @@ repos:
   rev: v5.0.0 # Latest stable version
   hooks:
   - id: check-merge-conflict
+  - id: trailing-whitespace
+    exclude: '\.py$' # Exclude Python files as Ruff already handles them
   - id: check-added-large-files
     args: ['--maxkb=1000']
   - id: end-of-file-fixer
@@ -83,10 +85,8 @@ repos:
   - id: distro-codegen
     name: Distribution Template Codegen
     additional_dependencies:
-    - rich
-    - pydantic
    - uv==0.6.0
-    entry: uv run python -m llama_stack.scripts.distro_codegen
+    entry: uv run --extra codegen python -m llama_stack.scripts.distro_codegen
     language: python
     pass_filenames: false
     require_serial: true

CHANGELOG.md (new file; diff suppressed because it is too large)


@@ -64,10 +64,10 @@ You can install `uv` by following this [guide](https://docs.astral.sh/uv/getting
 You can install the dependencies by running:

 ```bash
-$ cd llama-stack
-$ uv sync --extra dev
-$ uv pip install -e .
-$ source .venv/bin/activate
+cd llama-stack
+uv sync --extra dev
+uv pip install -e .
+source .venv/bin/activate
 ```

 Note that you can create a dotenv file `.env` that includes necessary environment variables:
@@ -80,7 +80,7 @@ LLAMA_STACK_CONFIG=
 And then use this dotenv file when running client SDK tests via the following:

 ```bash
-$ uv run --env-file .env -- pytest -v tests/client-sdk/inference/test_text_inference.py
+uv run --env-file .env -- pytest -v tests/api/inference/test_text_inference.py
 ```

 ## Pre-commit Hooks
@@ -88,7 +88,7 @@ $ uv run --env-file .env -- pytest -v tests/client-sdk/inference/test_text_infer
 We use [pre-commit](https://pre-commit.com/) to run linting and formatting checks on your code. You can install the pre-commit hooks by running:

 ```bash
-$ uv run pre-commit install
+uv run pre-commit install
 ```

 After that, pre-commit hooks will run automatically before each commit.
@@ -96,7 +96,7 @@ After that, pre-commit hooks will run automatically before each commit.
 Alternatively, if you don't want to install the pre-commit hooks, you can run the checks manually by running:

 ```bash
-$ uv run pre-commit run --all-files
+uv run pre-commit run --all-files
 ```

 > [!CAUTION]
@@ -107,8 +107,8 @@ $ uv run pre-commit run --all-files
 To add a new dependency to the project, you can use the `uv` command. For example, to add `foo` to the project, you can run:

 ```bash
-$ uv add foo
-$ uv sync
+uv add foo
+uv sync
 ```

 ## Coding Style
@@ -123,15 +123,15 @@ Some tips about common tasks you work on while contributing to Llama Stack:
 ### Using `llama stack build`

-Building a stack image (conda / docker) will use the production version of the `llama-stack`, `llama-models` and `llama-stack-client` packages. If you are developing with a llama-stack repository checked out and need your code to be reflected in the stack image, set `LLAMA_STACK_DIR` and `LLAMA_MODELS_DIR` to the appropriate checked out directories when running any of the `llama` CLI commands.
+Building a stack image (conda / docker) will use the production version of the `llama-stack` and `llama-stack-client` packages. If you are developing with a llama-stack repository checked out and need your code to be reflected in the stack image, set `LLAMA_STACK_DIR` and `LLAMA_STACK_CLIENT_DIR` to the appropriate checked out directories when running any of the `llama` CLI commands.

 Example:
 ```bash
-$ cd work/
-$ git clone https://github.com/meta-llama/llama-stack.git
-$ git clone https://github.com/meta-llama/llama-models.git
-$ cd llama-stack
-$ LLAMA_STACK_DIR=$(pwd) LLAMA_MODELS_DIR=../llama-models llama stack build --template <...>
+cd work/
+git clone https://github.com/meta-llama/llama-stack.git
+git clone https://github.com/meta-llama/llama-stack-client-python.git
+cd llama-stack
+LLAMA_STACK_DIR=$(pwd) LLAMA_STACK_CLIENT_DIR=../llama-stack-client-python llama stack build --template <...>
 ```
@@ -144,14 +144,14 @@ If you have made changes to a provider's configuration in any form (introducing
 If you are making changes to the documentation at [https://llama-stack.readthedocs.io/en/latest/](https://llama-stack.readthedocs.io/en/latest/), you can use the following command to build the documentation and preview your changes. You will need [Sphinx](https://www.sphinx-doc.org/en/master/) and the readthedocs theme.

 ```bash
-$ cd llama-stack/docs
-$ uv sync --extra docs
+cd llama-stack/docs
+uv sync --extra docs

 # This rebuilds the documentation pages.
-$ uv run make html
+uv run make html

 # This will start a local server (usually at http://127.0.0.1:8000) that automatically rebuilds and refreshes when you make changes to the documentation.
-$ uv run sphinx-autobuild source build/html --write-all
+uv run sphinx-autobuild source build/html --write-all
 ```

 ### Update API Documentation
@@ -159,8 +159,8 @@ $ uv run sphinx-autobuild source build/html --write-all
 If you modify or add new API endpoints, update the API documentation accordingly. You can do this by running the following command:

 ```bash
-$ uv sync --extra dev
-$ uv run ./docs/openapi_generator/run_openapi_generator.sh
+uv sync --extra dev
+uv run ./docs/openapi_generator/run_openapi_generator.sh
 ```

 The generated API documentation will be available in `docs/_static/`. Make sure to review the changes before committing.


@@ -1,5 +1,6 @@
 include pyproject.toml
 include distributions/dependencies.json
+include llama_stack/models/llama/llama3/tokenizer.model
 include llama_stack/distribution/*.sh
 include llama_stack/cli/scripts/*.sh
 include llama_stack/templates/*/*.yaml


@@ -32,7 +32,7 @@ Llama Stack standardizes the core building blocks that simplify AI application d
 By reducing friction and complexity, Llama Stack empowers developers to focus on what they do best: building transformative generative AI applications.

 ### API Providers
 Here is a list of the various API providers and available distributions that can help developers get started easily with Llama Stack.

 | **API Provider Builder** | **Environments** | **Agents** | **Inference** | **Memory** | **Safety** | **Telemetry** |
 |:------------------------:|:----------------------:|:----------:|:-------------:|:----------:|:----------:|:-------------:|

File diff suppressed because it is too large

File diff suppressed because it is too large


@@ -141,7 +141,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 1,
+"execution_count": 18,
 "id": "E1UFuJC570Tk",
 "metadata": {
 "colab": {
@ -326,54 +326,108 @@
" type: sqlite\n", " type: sqlite\n",
"models:\n", "models:\n",
"- metadata: <span style=\"font-weight: bold\">{}</span>\n", "- metadata: <span style=\"font-weight: bold\">{}</span>\n",
" model_id: meta-llama/Meta-Llama-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">3.1</span>-8B-Instruct-Turbo\n",
" model_type: !!python/object/apply:llama_stack.apis.models.models.ModelType\n",
" - llm\n",
" provider_id: together\n",
" provider_model_id: meta-llama/Meta-Llama-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">3.1</span>-8B-Instruct-Turbo\n",
"- metadata: <span style=\"font-weight: bold\">{}</span>\n",
" model_id: meta-llama/Llama-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">3.1</span>-8B-Instruct\n", " model_id: meta-llama/Llama-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">3.1</span>-8B-Instruct\n",
" model_type: !!python/object/apply:llama_stack.apis.models.models.ModelType\n", " model_type: !!python/object/apply:llama_stack.apis.models.models.ModelType\n",
" - llm\n", " - llm\n",
" provider_id: together\n", " provider_id: together\n",
" provider_model_id: meta-llama/Meta-Llama-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">3.1</span>-8B-Instruct-Turbo\n", " provider_model_id: meta-llama/Meta-Llama-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">3.1</span>-8B-Instruct-Turbo\n",
"- metadata: <span style=\"font-weight: bold\">{}</span>\n", "- metadata: <span style=\"font-weight: bold\">{}</span>\n",
" model_id: meta-llama/Meta-Llama-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">3.1</span>-70B-Instruct-Turbo\n",
" model_type: !!python/object/apply:llama_stack.apis.models.models.ModelType\n",
" - llm\n",
" provider_id: together\n",
" provider_model_id: meta-llama/Meta-Llama-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">3.1</span>-70B-Instruct-Turbo\n",
"- metadata: <span style=\"font-weight: bold\">{}</span>\n",
" model_id: meta-llama/Llama-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">3.1</span>-70B-Instruct\n", " model_id: meta-llama/Llama-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">3.1</span>-70B-Instruct\n",
" model_type: !!python/object/apply:llama_stack.apis.models.models.ModelType\n", " model_type: !!python/object/apply:llama_stack.apis.models.models.ModelType\n",
" - llm\n", " - llm\n",
" provider_id: together\n", " provider_id: together\n",
" provider_model_id: meta-llama/Meta-Llama-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">3.1</span>-70B-Instruct-Turbo\n", " provider_model_id: meta-llama/Meta-Llama-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">3.1</span>-70B-Instruct-Turbo\n",
"- metadata: <span style=\"font-weight: bold\">{}</span>\n", "- metadata: <span style=\"font-weight: bold\">{}</span>\n",
" model_id: meta-llama/Meta-Llama-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">3.1</span>-405B-Instruct-Turbo\n",
" model_type: !!python/object/apply:llama_stack.apis.models.models.ModelType\n",
" - llm\n",
" provider_id: together\n",
" provider_model_id: meta-llama/Meta-Llama-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">3.1</span>-405B-Instruct-Turbo\n",
"- metadata: <span style=\"font-weight: bold\">{}</span>\n",
" model_id: meta-llama/Llama-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">3.1</span>-405B-Instruct-FP8\n", " model_id: meta-llama/Llama-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">3.1</span>-405B-Instruct-FP8\n",
" model_type: !!python/object/apply:llama_stack.apis.models.models.ModelType\n", " model_type: !!python/object/apply:llama_stack.apis.models.models.ModelType\n",
" - llm\n", " - llm\n",
" provider_id: together\n", " provider_id: together\n",
" provider_model_id: meta-llama/Meta-Llama-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">3.1</span>-405B-Instruct-Turbo\n", " provider_model_id: meta-llama/Meta-Llama-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">3.1</span>-405B-Instruct-Turbo\n",
"- metadata: <span style=\"font-weight: bold\">{}</span>\n", "- metadata: <span style=\"font-weight: bold\">{}</span>\n",
" model_id: meta-llama/Llama-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">3.2</span>-3B-Instruct-Turbo\n",
" model_type: !!python/object/apply:llama_stack.apis.models.models.ModelType\n",
" - llm\n",
" provider_id: together\n",
" provider_model_id: meta-llama/Llama-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">3.2</span>-3B-Instruct-Turbo\n",
"- metadata: <span style=\"font-weight: bold\">{}</span>\n",
" model_id: meta-llama/Llama-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">3.2</span>-3B-Instruct\n", " model_id: meta-llama/Llama-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">3.2</span>-3B-Instruct\n",
" model_type: !!python/object/apply:llama_stack.apis.models.models.ModelType\n", " model_type: !!python/object/apply:llama_stack.apis.models.models.ModelType\n",
" - llm\n", " - llm\n",
" provider_id: together\n", " provider_id: together\n",
" provider_model_id: meta-llama/Llama-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">3.2</span>-3B-Instruct-Turbo\n", " provider_model_id: meta-llama/Llama-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">3.2</span>-3B-Instruct-Turbo\n",
"- metadata: <span style=\"font-weight: bold\">{}</span>\n", "- metadata: <span style=\"font-weight: bold\">{}</span>\n",
" model_id: meta-llama/Llama-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">3.2</span>-11B-Vision-Instruct-Turbo\n",
" model_type: !!python/object/apply:llama_stack.apis.models.models.ModelType\n",
" - llm\n",
" provider_id: together\n",
" provider_model_id: meta-llama/Llama-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">3.2</span>-11B-Vision-Instruct-Turbo\n",
"- metadata: <span style=\"font-weight: bold\">{}</span>\n",
" model_id: meta-llama/Llama-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">3.2</span>-11B-Vision-Instruct\n", " model_id: meta-llama/Llama-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">3.2</span>-11B-Vision-Instruct\n",
" model_type: !!python/object/apply:llama_stack.apis.models.models.ModelType\n", " model_type: !!python/object/apply:llama_stack.apis.models.models.ModelType\n",
" - llm\n", " - llm\n",
" provider_id: together\n", " provider_id: together\n",
" provider_model_id: meta-llama/Llama-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">3.2</span>-11B-Vision-Instruct-Turbo\n", " provider_model_id: meta-llama/Llama-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">3.2</span>-11B-Vision-Instruct-Turbo\n",
"- metadata: <span style=\"font-weight: bold\">{}</span>\n", "- metadata: <span style=\"font-weight: bold\">{}</span>\n",
" model_id: meta-llama/Llama-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">3.2</span>-90B-Vision-Instruct-Turbo\n",
" model_type: !!python/object/apply:llama_stack.apis.models.models.ModelType\n",
" - llm\n",
" provider_id: together\n",
" provider_model_id: meta-llama/Llama-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">3.2</span>-90B-Vision-Instruct-Turbo\n",
"- metadata: <span style=\"font-weight: bold\">{}</span>\n",
" model_id: meta-llama/Llama-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">3.2</span>-90B-Vision-Instruct\n", " model_id: meta-llama/Llama-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">3.2</span>-90B-Vision-Instruct\n",
" model_type: !!python/object/apply:llama_stack.apis.models.models.ModelType\n", " model_type: !!python/object/apply:llama_stack.apis.models.models.ModelType\n",
" - llm\n", " - llm\n",
" provider_id: together\n", " provider_id: together\n",
" provider_model_id: meta-llama/Llama-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">3.2</span>-90B-Vision-Instruct-Turbo\n", " provider_model_id: meta-llama/Llama-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">3.2</span>-90B-Vision-Instruct-Turbo\n",
"- metadata: <span style=\"font-weight: bold\">{}</span>\n", "- metadata: <span style=\"font-weight: bold\">{}</span>\n",
" model_id: meta-llama/Llama-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">3.3</span>-70B-Instruct-Turbo\n",
" model_type: !!python/object/apply:llama_stack.apis.models.models.ModelType\n",
" - llm\n",
" provider_id: together\n",
" provider_model_id: meta-llama/Llama-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">3.3</span>-70B-Instruct-Turbo\n",
"- metadata: <span style=\"font-weight: bold\">{}</span>\n",
" model_id: meta-llama/Llama-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">3.3</span>-70B-Instruct\n", " model_id: meta-llama/Llama-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">3.3</span>-70B-Instruct\n",
" model_type: !!python/object/apply:llama_stack.apis.models.models.ModelType\n", " model_type: !!python/object/apply:llama_stack.apis.models.models.ModelType\n",
" - llm\n", " - llm\n",
" provider_id: together\n", " provider_id: together\n",
" provider_model_id: meta-llama/Llama-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">3.3</span>-70B-Instruct-Turbo\n", " provider_model_id: meta-llama/Llama-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">3.3</span>-70B-Instruct-Turbo\n",
"- metadata: <span style=\"font-weight: bold\">{}</span>\n", "- metadata: <span style=\"font-weight: bold\">{}</span>\n",
" model_id: meta-llama/Meta-Llama-Guard-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">3</span>-8B\n",
" model_type: !!python/object/apply:llama_stack.apis.models.models.ModelType\n",
" - llm\n",
" provider_id: together\n",
" provider_model_id: meta-llama/Meta-Llama-Guard-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">3</span>-8B\n",
"- metadata: <span style=\"font-weight: bold\">{}</span>\n",
" model_id: meta-llama/Llama-Guard-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">3</span>-8B\n", " model_id: meta-llama/Llama-Guard-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">3</span>-8B\n",
" model_type: !!python/object/apply:llama_stack.apis.models.models.ModelType\n", " model_type: !!python/object/apply:llama_stack.apis.models.models.ModelType\n",
" - llm\n", " - llm\n",
" provider_id: together\n", " provider_id: together\n",
" provider_model_id: meta-llama/Meta-Llama-Guard-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">3</span>-8B\n", " provider_model_id: meta-llama/Meta-Llama-Guard-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">3</span>-8B\n",
"- metadata: <span style=\"font-weight: bold\">{}</span>\n", "- metadata: <span style=\"font-weight: bold\">{}</span>\n",
" model_id: meta-llama/Llama-Guard-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">3</span>-11B-Vision-Turbo\n",
" model_type: !!python/object/apply:llama_stack.apis.models.models.ModelType\n",
" - llm\n",
" provider_id: together\n",
" provider_model_id: meta-llama/Llama-Guard-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">3</span>-11B-Vision-Turbo\n",
"- metadata: <span style=\"font-weight: bold\">{}</span>\n",
" model_id: meta-llama/Llama-Guard-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">3</span>-11B-Vision\n", " model_id: meta-llama/Llama-Guard-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">3</span>-11B-Vision\n",
" model_type: !!python/object/apply:llama_stack.apis.models.models.ModelType\n", " model_type: !!python/object/apply:llama_stack.apis.models.models.ModelType\n",
" - llm\n", " - llm\n",
@@ -473,6 +527,9 @@
 " - config: <span style=\"font-weight: bold\">{}</span>\n",
 " provider_id: model-context-protocol\n",
 " provider_type: remote::model-context-protocol\n",
+" - config: <span style=\"font-weight: bold\">{}</span>\n",
+" provider_id: wolfram-alpha\n",
+" provider_type: remote::wolfram-alpha\n",
 " vector_io:\n",
 " - config:\n",
 " kvstore:\n",
@@ -504,6 +561,10 @@
 " mcp_endpoint: null\n",
 " provider_id: code-interpreter\n",
 " toolgroup_id: builtin::code_interpreter\n",
+"- args: null\n",
+" mcp_endpoint: null\n",
+" provider_id: wolfram-alpha\n",
+" toolgroup_id: builtin::wolfram_alpha\n",
 "vector_dbs: <span style=\"font-weight: bold\">[]</span>\n",
 "version: <span style=\"color: #008000; text-decoration-color: #008000\">'2'</span>\n",
 "\n",
@ -530,54 +591,108 @@
" type: sqlite\n", " type: sqlite\n",
"models:\n", "models:\n",
"- metadata: \u001b[1m{\u001b[0m\u001b[1m}\u001b[0m\n", "- metadata: \u001b[1m{\u001b[0m\u001b[1m}\u001b[0m\n",
" model_id: meta-llama/Meta-Llama-\u001b[1;36m3.1\u001b[0m-8B-Instruct-Turbo\n",
" model_type: !!python/object/apply:llama_stack.apis.models.models.ModelType\n",
" - llm\n",
" provider_id: together\n",
" provider_model_id: meta-llama/Meta-Llama-\u001b[1;36m3.1\u001b[0m-8B-Instruct-Turbo\n",
"- metadata: \u001b[1m{\u001b[0m\u001b[1m}\u001b[0m\n",
" model_id: meta-llama/Llama-\u001b[1;36m3.1\u001b[0m-8B-Instruct\n", " model_id: meta-llama/Llama-\u001b[1;36m3.1\u001b[0m-8B-Instruct\n",
" model_type: !!python/object/apply:llama_stack.apis.models.models.ModelType\n", " model_type: !!python/object/apply:llama_stack.apis.models.models.ModelType\n",
" - llm\n", " - llm\n",
" provider_id: together\n", " provider_id: together\n",
" provider_model_id: meta-llama/Meta-Llama-\u001b[1;36m3.1\u001b[0m-8B-Instruct-Turbo\n", " provider_model_id: meta-llama/Meta-Llama-\u001b[1;36m3.1\u001b[0m-8B-Instruct-Turbo\n",
"- metadata: \u001b[1m{\u001b[0m\u001b[1m}\u001b[0m\n", "- metadata: \u001b[1m{\u001b[0m\u001b[1m}\u001b[0m\n",
" model_id: meta-llama/Meta-Llama-\u001b[1;36m3.1\u001b[0m-70B-Instruct-Turbo\n",
" model_type: !!python/object/apply:llama_stack.apis.models.models.ModelType\n",
" - llm\n",
" provider_id: together\n",
" provider_model_id: meta-llama/Meta-Llama-\u001b[1;36m3.1\u001b[0m-70B-Instruct-Turbo\n",
"- metadata: \u001b[1m{\u001b[0m\u001b[1m}\u001b[0m\n",
" model_id: meta-llama/Llama-\u001b[1;36m3.1\u001b[0m-70B-Instruct\n", " model_id: meta-llama/Llama-\u001b[1;36m3.1\u001b[0m-70B-Instruct\n",
" model_type: !!python/object/apply:llama_stack.apis.models.models.ModelType\n", " model_type: !!python/object/apply:llama_stack.apis.models.models.ModelType\n",
" - llm\n", " - llm\n",
" provider_id: together\n", " provider_id: together\n",
" provider_model_id: meta-llama/Meta-Llama-\u001b[1;36m3.1\u001b[0m-70B-Instruct-Turbo\n", " provider_model_id: meta-llama/Meta-Llama-\u001b[1;36m3.1\u001b[0m-70B-Instruct-Turbo\n",
"- metadata: \u001b[1m{\u001b[0m\u001b[1m}\u001b[0m\n", "- metadata: \u001b[1m{\u001b[0m\u001b[1m}\u001b[0m\n",
" model_id: meta-llama/Meta-Llama-\u001b[1;36m3.1\u001b[0m-405B-Instruct-Turbo\n",
" model_type: !!python/object/apply:llama_stack.apis.models.models.ModelType\n",
" - llm\n",
" provider_id: together\n",
" provider_model_id: meta-llama/Meta-Llama-\u001b[1;36m3.1\u001b[0m-405B-Instruct-Turbo\n",
"- metadata: \u001b[1m{\u001b[0m\u001b[1m}\u001b[0m\n",
" model_id: meta-llama/Llama-\u001b[1;36m3.1\u001b[0m-405B-Instruct-FP8\n", " model_id: meta-llama/Llama-\u001b[1;36m3.1\u001b[0m-405B-Instruct-FP8\n",
" model_type: !!python/object/apply:llama_stack.apis.models.models.ModelType\n", " model_type: !!python/object/apply:llama_stack.apis.models.models.ModelType\n",
" - llm\n", " - llm\n",
" provider_id: together\n", " provider_id: together\n",
" provider_model_id: meta-llama/Meta-Llama-\u001b[1;36m3.1\u001b[0m-405B-Instruct-Turbo\n", " provider_model_id: meta-llama/Meta-Llama-\u001b[1;36m3.1\u001b[0m-405B-Instruct-Turbo\n",
"- metadata: \u001b[1m{\u001b[0m\u001b[1m}\u001b[0m\n", "- metadata: \u001b[1m{\u001b[0m\u001b[1m}\u001b[0m\n",
" model_id: meta-llama/Llama-\u001b[1;36m3.2\u001b[0m-3B-Instruct-Turbo\n",
" model_type: !!python/object/apply:llama_stack.apis.models.models.ModelType\n",
" - llm\n",
" provider_id: together\n",
" provider_model_id: meta-llama/Llama-\u001b[1;36m3.2\u001b[0m-3B-Instruct-Turbo\n",
"- metadata: \u001b[1m{\u001b[0m\u001b[1m}\u001b[0m\n",
" model_id: meta-llama/Llama-\u001b[1;36m3.2\u001b[0m-3B-Instruct\n", " model_id: meta-llama/Llama-\u001b[1;36m3.2\u001b[0m-3B-Instruct\n",
" model_type: !!python/object/apply:llama_stack.apis.models.models.ModelType\n", " model_type: !!python/object/apply:llama_stack.apis.models.models.ModelType\n",
" - llm\n", " - llm\n",
" provider_id: together\n", " provider_id: together\n",
" provider_model_id: meta-llama/Llama-\u001b[1;36m3.2\u001b[0m-3B-Instruct-Turbo\n", " provider_model_id: meta-llama/Llama-\u001b[1;36m3.2\u001b[0m-3B-Instruct-Turbo\n",
"- metadata: \u001b[1m{\u001b[0m\u001b[1m}\u001b[0m\n", "- metadata: \u001b[1m{\u001b[0m\u001b[1m}\u001b[0m\n",
" model_id: meta-llama/Llama-\u001b[1;36m3.2\u001b[0m-11B-Vision-Instruct-Turbo\n",
" model_type: !!python/object/apply:llama_stack.apis.models.models.ModelType\n",
" - llm\n",
" provider_id: together\n",
" provider_model_id: meta-llama/Llama-\u001b[1;36m3.2\u001b[0m-11B-Vision-Instruct-Turbo\n",
"- metadata: \u001b[1m{\u001b[0m\u001b[1m}\u001b[0m\n",
" model_id: meta-llama/Llama-\u001b[1;36m3.2\u001b[0m-11B-Vision-Instruct\n", " model_id: meta-llama/Llama-\u001b[1;36m3.2\u001b[0m-11B-Vision-Instruct\n",
" model_type: !!python/object/apply:llama_stack.apis.models.models.ModelType\n", " model_type: !!python/object/apply:llama_stack.apis.models.models.ModelType\n",
" - llm\n", " - llm\n",
" provider_id: together\n", " provider_id: together\n",
" provider_model_id: meta-llama/Llama-\u001b[1;36m3.2\u001b[0m-11B-Vision-Instruct-Turbo\n", " provider_model_id: meta-llama/Llama-\u001b[1;36m3.2\u001b[0m-11B-Vision-Instruct-Turbo\n",
"- metadata: \u001b[1m{\u001b[0m\u001b[1m}\u001b[0m\n", "- metadata: \u001b[1m{\u001b[0m\u001b[1m}\u001b[0m\n",
" model_id: meta-llama/Llama-\u001b[1;36m3.2\u001b[0m-90B-Vision-Instruct-Turbo\n",
" model_type: !!python/object/apply:llama_stack.apis.models.models.ModelType\n",
" - llm\n",
" provider_id: together\n",
" provider_model_id: meta-llama/Llama-\u001b[1;36m3.2\u001b[0m-90B-Vision-Instruct-Turbo\n",
"- metadata: \u001b[1m{\u001b[0m\u001b[1m}\u001b[0m\n",
" model_id: meta-llama/Llama-\u001b[1;36m3.2\u001b[0m-90B-Vision-Instruct\n", " model_id: meta-llama/Llama-\u001b[1;36m3.2\u001b[0m-90B-Vision-Instruct\n",
" model_type: !!python/object/apply:llama_stack.apis.models.models.ModelType\n", " model_type: !!python/object/apply:llama_stack.apis.models.models.ModelType\n",
" - llm\n", " - llm\n",
" provider_id: together\n", " provider_id: together\n",
" provider_model_id: meta-llama/Llama-\u001b[1;36m3.2\u001b[0m-90B-Vision-Instruct-Turbo\n", " provider_model_id: meta-llama/Llama-\u001b[1;36m3.2\u001b[0m-90B-Vision-Instruct-Turbo\n",
"- metadata: \u001b[1m{\u001b[0m\u001b[1m}\u001b[0m\n", "- metadata: \u001b[1m{\u001b[0m\u001b[1m}\u001b[0m\n",
" model_id: meta-llama/Llama-\u001b[1;36m3.3\u001b[0m-70B-Instruct-Turbo\n",
" model_type: !!python/object/apply:llama_stack.apis.models.models.ModelType\n",
" - llm\n",
" provider_id: together\n",
" provider_model_id: meta-llama/Llama-\u001b[1;36m3.3\u001b[0m-70B-Instruct-Turbo\n",
"- metadata: \u001b[1m{\u001b[0m\u001b[1m}\u001b[0m\n",
" model_id: meta-llama/Llama-\u001b[1;36m3.3\u001b[0m-70B-Instruct\n", " model_id: meta-llama/Llama-\u001b[1;36m3.3\u001b[0m-70B-Instruct\n",
" model_type: !!python/object/apply:llama_stack.apis.models.models.ModelType\n", " model_type: !!python/object/apply:llama_stack.apis.models.models.ModelType\n",
" - llm\n", " - llm\n",
" provider_id: together\n", " provider_id: together\n",
" provider_model_id: meta-llama/Llama-\u001b[1;36m3.3\u001b[0m-70B-Instruct-Turbo\n", " provider_model_id: meta-llama/Llama-\u001b[1;36m3.3\u001b[0m-70B-Instruct-Turbo\n",
"- metadata: \u001b[1m{\u001b[0m\u001b[1m}\u001b[0m\n", "- metadata: \u001b[1m{\u001b[0m\u001b[1m}\u001b[0m\n",
" model_id: meta-llama/Meta-Llama-Guard-\u001b[1;36m3\u001b[0m-8B\n",
" model_type: !!python/object/apply:llama_stack.apis.models.models.ModelType\n",
" - llm\n",
" provider_id: together\n",
" provider_model_id: meta-llama/Meta-Llama-Guard-\u001b[1;36m3\u001b[0m-8B\n",
"- metadata: \u001b[1m{\u001b[0m\u001b[1m}\u001b[0m\n",
" model_id: meta-llama/Llama-Guard-\u001b[1;36m3\u001b[0m-8B\n", " model_id: meta-llama/Llama-Guard-\u001b[1;36m3\u001b[0m-8B\n",
" model_type: !!python/object/apply:llama_stack.apis.models.models.ModelType\n", " model_type: !!python/object/apply:llama_stack.apis.models.models.ModelType\n",
" - llm\n", " - llm\n",
" provider_id: together\n", " provider_id: together\n",
" provider_model_id: meta-llama/Meta-Llama-Guard-\u001b[1;36m3\u001b[0m-8B\n", " provider_model_id: meta-llama/Meta-Llama-Guard-\u001b[1;36m3\u001b[0m-8B\n",
"- metadata: \u001b[1m{\u001b[0m\u001b[1m}\u001b[0m\n", "- metadata: \u001b[1m{\u001b[0m\u001b[1m}\u001b[0m\n",
" model_id: meta-llama/Llama-Guard-\u001b[1;36m3\u001b[0m-11B-Vision-Turbo\n",
" model_type: !!python/object/apply:llama_stack.apis.models.models.ModelType\n",
" - llm\n",
" provider_id: together\n",
" provider_model_id: meta-llama/Llama-Guard-\u001b[1;36m3\u001b[0m-11B-Vision-Turbo\n",
"- metadata: \u001b[1m{\u001b[0m\u001b[1m}\u001b[0m\n",
" model_id: meta-llama/Llama-Guard-\u001b[1;36m3\u001b[0m-11B-Vision\n", " model_id: meta-llama/Llama-Guard-\u001b[1;36m3\u001b[0m-11B-Vision\n",
" model_type: !!python/object/apply:llama_stack.apis.models.models.ModelType\n", " model_type: !!python/object/apply:llama_stack.apis.models.models.ModelType\n",
" - llm\n", " - llm\n",
@@ -677,6 +792,9 @@
 " - config: \u001b[1m{\u001b[0m\u001b[1m}\u001b[0m\n",
 " provider_id: model-context-protocol\n",
 " provider_type: remote::model-context-protocol\n",
+" - config: \u001b[1m{\u001b[0m\u001b[1m}\u001b[0m\n",
+" provider_id: wolfram-alpha\n",
+" provider_type: remote::wolfram-alpha\n",
 " vector_io:\n",
 " - config:\n",
 " kvstore:\n",
@@ -708,6 +826,10 @@
 " mcp_endpoint: null\n",
 " provider_id: code-interpreter\n",
 " toolgroup_id: builtin::code_interpreter\n",
+"- args: null\n",
+" mcp_endpoint: null\n",
+" provider_id: wolfram-alpha\n",
+" toolgroup_id: builtin::wolfram_alpha\n",
 "vector_dbs: \u001b[1m[\u001b[0m\u001b[1m]\u001b[0m\n",
 "version: \u001b[32m'2'\u001b[0m\n",
 "\n"
@@ -1513,18 +1635,14 @@
 "source": [
 "from llama_stack_client.lib.agents.agent import Agent\n",
 "from llama_stack_client.lib.agents.event_logger import EventLogger\n",
-"from llama_stack_client.types.agent_create_params import AgentConfig\n",
 "from termcolor import cprint\n",
 "\n",
-"agent_config = AgentConfig(\n",
+"agent = Agent(\n",
+"    client, \n",
 "    model=model_id,\n",
 "    instructions=\"You are a helpful assistant\",\n",
-"    toolgroups=[\"builtin::websearch\"],\n",
-"    input_shields=[],\n",
-"    output_shields=[],\n",
-"    enable_session_persistence=False,\n",
+"    tools=[\"builtin::websearch\"],\n",
 ")\n",
-"agent = Agent(client, agent_config)\n",
 "user_prompts = [\n",
 "    \"Hello\",\n",
 "    \"Which teams played in the NBA western conference finals of 2024\",\n",
@@ -1693,7 +1811,6 @@
 "import uuid\n",
 "from llama_stack_client.lib.agents.agent import Agent\n",
 "from llama_stack_client.lib.agents.event_logger import EventLogger\n",
-"from llama_stack_client.types.agent_create_params import AgentConfig\n",
 "from termcolor import cprint\n",
 "from llama_stack_client.types import Document\n",
 "\n",
@@ -1719,11 +1836,11 @@
 "    vector_db_id=vector_db_id,\n",
 "    chunk_size_in_tokens=512,\n",
 ")\n",
-"agent_config = AgentConfig(\n",
+"rag_agent = Agent(\n",
+"    client, \n",
 "    model=model_id,\n",
 "    instructions=\"You are a helpful assistant\",\n",
-"    enable_session_persistence=False,\n",
-"    toolgroups = [\n",
+"    tools = [\n",
 "        {\n",
 "          \"name\": \"builtin::rag/knowledge_search\",\n",
 "          \"args\" : {\n",
@@ -1732,7 +1849,6 @@
 "        }\n",
 "    ],\n",
 ")\n",
-"rag_agent = Agent(client, agent_config)\n",
 "session_id = rag_agent.create_session(\"test-session\")\n",
 "user_prompts = [\n",
 "    \"What are the top 5 topics that were explained? Only list succinct bullet points.\",\n",
@@ -1856,23 +1972,19 @@
 "source": [
 "from llama_stack_client.types.agents.turn_create_params import Document\n",
 "\n",
-"agent_config = AgentConfig(\n",
+"codex_agent = Agent(\n",
+"    client, \n",
+"    model=\"meta-llama/Llama-3.1-8B-Instruct\",\n",
+"    instructions=\"You are a helpful assistant\",\n",
+"    tools=[\n",
+"        \"builtin::code_interpreter\",\n",
+"        \"builtin::websearch\"\n",
+"    ],\n",
 "    sampling_params = {\n",
 "        \"max_tokens\" : 4096,\n",
 "        \"temperature\": 0.0\n",
 "    },\n",
-"    model=\"meta-llama/Llama-3.1-8B-Instruct\",\n",
-"    instructions=\"You are a helpful assistant\",\n",
-"    toolgroups=[\n",
-"        \"builtin::code_interpreter\",\n",
-"        \"builtin::websearch\"\n",
-"    ],\n",
-"    tool_choice=\"auto\",\n",
-"    input_shields=[],\n",
-"    output_shields=[],\n",
-"    enable_session_persistence=False,\n",
 ")\n",
-"codex_agent = Agent(client, agent_config)\n",
 "session_id = codex_agent.create_session(\"test-session\")\n",
 "\n",
 "\n",
@@ -2782,18 +2894,14 @@
 "# NBVAL_SKIP\n",
 "from llama_stack_client.lib.agents.agent import Agent\n",
 "from llama_stack_client.lib.agents.event_logger import EventLogger\n",
-"from llama_stack_client.types.agent_create_params import AgentConfig\n",
 "from termcolor import cprint\n",
 "\n",
-"agent_config = AgentConfig(\n",
+"agent = Agent(\n",
+"    client, \n",
 "    model=model_id,\n",
 "    instructions=\"You are a helpful assistant\",\n",
-"    toolgroups=[\"mcp::filesystem\"],\n",
-"    input_shields=[],\n",
-"    output_shields=[],\n",
-"    enable_session_persistence=False,\n",
+"    tools=[\"mcp::filesystem\"],\n",
 ")\n",
-"agent = Agent(client, agent_config)\n",
 "user_prompts = [\n",
 "    \"Hello\",\n",
 "    \"list all the files /content\",\n",
@@ -2888,17 +2996,13 @@
 "source": [
 "from llama_stack_client.lib.agents.agent import Agent\n",
 "from llama_stack_client.lib.agents.event_logger import EventLogger\n",
-"from llama_stack_client.types.agent_create_params import AgentConfig\n",
 "\n",
-"agent_config = AgentConfig(\n",
+"agent = Agent(\n",
+"    client, \n",
 "    model=\"meta-llama/Llama-3.3-70B-Instruct\",\n",
 "    instructions=\"You are a helpful assistant. Use search tool to answer the questions. \",\n",
-"    toolgroups=[\"builtin::websearch\"],\n",
-"    input_shields=[],\n",
-"    output_shields=[],\n",
-"    enable_session_persistence=False,\n",
+"    tools=[\"builtin::websearch\"],\n",
 ")\n",
-"agent = Agent(client, agent_config)\n",
 "user_prompts = [\n",
 "    \"Which teams played in the NBA western conference finals of 2024. Search the web for the answer.\",\n",
 "    \"In which episode and season of South Park does Bill Cosby (BSM-471) first appear? Give me the number and title. Search the web for the answer.\",\n",
@@ -4098,7 +4202,7 @@
 "source": [
 "## 4. Image Understanding with Llama 3.2\n",
 "\n",
-"Below is a complete example of using Together's Llama Stack 0.1 server at https://llama-stack.together.ai to ask Llama 3.2 questions about an image."
+"Below is a complete example of to ask Llama 3.2 questions about an image."
 ]
 },
 {
@@ -4106,14 +4210,12 @@
 "id": "82e381ec",
 "metadata": {},
 "source": [
-"### 4.1 Setup and helpers\n",
-"\n",
-"Below we install the Llama Stack client 0.1, download the example image, define two image helpers, and set Llama Stack Together server URL and Llama 3.2 model name.\n"
+"### 4.1 Setup and helpers\n"
 ]
 },
 {
 "cell_type": "code",
-"execution_count": 3,
+"execution_count": 1,
 "id": "44e05e16",
 "metadata": {},
 "outputs": [
@@ -4123,7 +4225,7 @@
 "text": [
 " % Total % Received % Xferd Average Speed Time Time Time Current\n",
 " Dload Upload Total Spent Left Speed\n",
-"100 275k 100 275k 0 0 780k 0 --:--:-- --:--:-- --:--:-- 780k\n"
+"100 275k 100 275k 0 0 905k 0 --:--:-- --:--:-- --:--:-- 906k\n"
 ]
 }
 ],
@@ -4133,32 +4235,13 @@
 },
 {
 "cell_type": "code",
-"execution_count": null,
-"id": "469750f7",
-"metadata": {},
-"outputs": [],
-"source": [
-"# NBVAL_SKIP\n",
-"from PIL import Image\n",
-"import matplotlib.pyplot as plt\n",
-"\n",
-"def display_image(path):\n",
-"  img = Image.open(path)\n",
-"  plt.imshow(img)\n",
-"  plt.axis('off')\n",
-"  plt.show()\n",
-"\n",
-"display_image(\"Llama_Repo.jpeg\")"
-]
-},
-{
-"cell_type": "code",
-"execution_count": null,
+"execution_count": 20,
 "id": "a2c1e1c2",
 "metadata": {},
 "outputs": [],
 "source": [
 "import base64\n",
+"vision_model_id = \"meta-llama/Llama-3.2-11B-Vision-Instruct\"\n",
 "\n",
 "def encode_image(image_path):\n",
 "  with open(image_path, \"rb\") as image_file:\n",
@@ -4167,19 +4250,6 @@
 "  return base64_url"
 ]
 },
-{
-"cell_type": "code",
-"execution_count": null,
-"id": "c565f99e",
-"metadata": {},
-"outputs": [],
-"source": [
-"from llama_stack_client import LlamaStackClient\n",
-"\n",
-"LLAMA_STACK_API_TOGETHER_URL=\"https://llama-stack.together.ai\"\n",
-"LLAMA32_11B_INSTRUCT = \"meta-llama/Llama-3.2-11B-Vision-Instruct\""
-]
-},
 {
 "cell_type": "markdown",
 "id": "7737cd41",
@ -4192,55 +4262,44 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": 21,
"id": "d7914894", "id": "d7914894",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"There are three llamas in the image. The llama in the middle is purple, the llama on the left is white, and the llama on the right is also white, but it is wearing a blue party hat. Therefore, there are two different colors of llama in the image: purple and white.\n"
]
}
],
"source": [ "source": [
"from llama_stack_client.lib.inference.event_logger import EventLogger\n", "response = client.inference.chat_completion(\n",
"\n", " messages=[\n",
"async def run_main(image_path: str, prompt):\n", " {\n",
" client = LlamaStackClient(\n", " \"role\": \"user\",\n",
" base_url=LLAMA_STACK_API_TOGETHER_URL,\n", " \"content\": [\n",
" )\n", " {\n",
"\n", " \"type\": \"image\",\n",
" message = {\n", " \"image\": {\n",
" \"role\": \"user\",\n", " \"url\": {\n",
" \"content\": [\n", " \"uri\": encode_image(\"Llama_Repo.jpeg\")\n",
" {\n", " }\n",
" \"type\": \"image\",\n", " }\n",
" \"image\": {\n", " },\n",
" \"url\": {\n", " {\n",
" \"uri\": encode_image(image_path)\n", " \"type\": \"text\",\n",
" }\n", " \"text\": \"How many different colors are those llamas? What are those colors?\",\n",
" }\n", " }\n",
" },\n", " ]\n",
" {\n", " }\n",
" \"type\": \"text\",\n", " ],\n",
" \"text\": prompt,\n", " model_id=vision_model_id,\n",
" }\n", " stream=False,\n",
" ]\n", ")\n",
" }\n",
"\n", "\n",
" response = client.inference.chat_completion(\n", "print(response.completion_message.content)"
" messages=[message],\n",
" model_id=LLAMA32_11B_INSTRUCT,\n",
" stream=False,\n",
" )\n",
"\n",
" print(response.completion_message.content.lower().strip())"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "4ee09b97",
"metadata": {},
"outputs": [],
"source": [
"await run_main(\"Llama_Repo.jpeg\",\n",
" \"How many different colors are those llamas?\\\n",
" What are those colors?\")"
] ]
}, },
{ {
@ -4255,68 +4314,60 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": 19,
"id": "f9a83275", "id": "f9a83275",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[33minference> \u001b[0m\u001b[33mThere\u001b[0m\u001b[33m are\u001b[0m\u001b[33m three\u001b[0m\u001b[33m different\u001b[0m\u001b[33m colors\u001b[0m\u001b[33m of\u001b[0m\u001b[33m ll\u001b[0m\u001b[33mamas\u001b[0m\u001b[33m in\u001b[0m\u001b[33m the\u001b[0m\u001b[33m image\u001b[0m\u001b[33m.\u001b[0m\u001b[33m The\u001b[0m\u001b[33m first\u001b[0m\u001b[33m llama\u001b[0m\u001b[33m on\u001b[0m\u001b[33m the\u001b[0m\u001b[33m left\u001b[0m\u001b[33m is\u001b[0m\u001b[33m white\u001b[0m\u001b[33m,\u001b[0m\u001b[33m the\u001b[0m\u001b[33m second\u001b[0m\u001b[33m llama\u001b[0m\u001b[33m in\u001b[0m\u001b[33m the\u001b[0m\u001b[33m middle\u001b[0m\u001b[33m is\u001b[0m\u001b[33m purple\u001b[0m\u001b[33m,\u001b[0m\u001b[33m and\u001b[0m\u001b[33m the\u001b[0m\u001b[33m third\u001b[0m\u001b[33m llama\u001b[0m\u001b[33m on\u001b[0m\u001b[33m the\u001b[0m\u001b[33m right\u001b[0m\u001b[33m is\u001b[0m\u001b[33m white\u001b[0m\u001b[33m with\u001b[0m\u001b[33m a\u001b[0m\u001b[33m blue\u001b[0m\u001b[33m party\u001b[0m\u001b[33m hat\u001b[0m\u001b[33m.\u001b[0m\u001b[97m\u001b[0m\n",
"\u001b[30m\u001b[0m"
]
}
],
"source": [ "source": [
"from llama_stack_client.lib.agents.agent import Agent\n", "agent = Agent(\n",
"from llama_stack_client.lib.agents.event_logger import EventLogger\n", " client, \n",
"from llama_stack_client.types.agent_create_params import AgentConfig\n", " model=vision_model_id,\n",
" instructions=\"You are a helpful assistant\",\n",
")\n",
"session_id = agent.create_session(\"test-session\")\n",
"\n", "\n",
"async def run_main(image_path, prompt):\n", "response = agent.create_turn(\n",
" base64_image = encode_image(image_path)\n", " messages=[{\n",
"\n", " \"role\": \"user\",\n",
" client = LlamaStackClient(\n", " \"content\": [\n",
" base_url=LLAMA_STACK_API_TOGETHER_URL,\n", " {\n",
" )\n", " \"type\": \"image\",\n",
"\n", " \"image\": {\n",
" agent_config = AgentConfig(\n", " \"url\": {\n",
" model=LLAMA32_11B_INSTRUCT,\n", " \"uri\": encode_image(\"Llama_Repo.jpeg\")\n",
" instructions=\"You are a helpful assistant\",\n", " }\n",
" enable_session_persistence=False,\n",
" toolgroups=[],\n",
" )\n",
"\n",
" agent = Agent(client, agent_config)\n",
" session_id = agent.create_session(\"test-session\")\n",
"\n",
" response = agent.create_turn(\n",
" messages=[{\n",
" \"role\": \"user\",\n",
" \"content\": [\n",
" {\n",
" \"type\": \"image\",\n",
" \"image\": {\n",
" \"url\": {\n",
" \"uri\": encode_image(image_path)\n",
" }\n",
" }\n",
" },\n",
" {\n",
" \"type\": \"text\",\n",
" \"text\": prompt,\n",
" }\n", " }\n",
" ]\n", " },\n",
" }],\n", " {\n",
" session_id=session_id,\n", " \"type\": \"text\",\n",
" )\n", " \"text\": \"How many different colors are those llamas? What are those colors?\",\n",
" }\n",
" ]\n",
" }],\n",
" session_id=session_id,\n",
")\n",
"\n", "\n",
" for log in EventLogger().log(response):\n", "for log in EventLogger().log(response):\n",
" log.print()" " log.print()\n",
" "
] ]
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": null,
"id": "15d0098b", "id": "f3352379",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": []
"await run_main(\"Llama_Repo.jpeg\",\n",
" \"How many different colors are those llamas?\\\n",
" What are those colors?\")"
]
} }
], ],
"metadata": { "metadata": {


@@ -3675,7 +3675,7 @@
 "    benchmark_id=\"llama3.2-3B-instruct:tax_eval\",\n",
 "    input_rows=eval_rows.rows,\n",
 "    scoring_functions=[\"braintrust::answer-similarity\"],\n",
-"    task_config={\n",
+"    benchmark_config={\n",
 "        \"type\": \"benchmark\",\n",
 "        \"eval_candidate\": {\n",
 "            \"type\": \"model\",\n",
@@ -6383,7 +6383,7 @@
 "    benchmark_id=\"Llama-3.2-3B-Instruct-sft-0:tax_eval\",\n",
 "    input_rows=eval_rows.rows,\n",
 "    scoring_functions=[\"braintrust::answer-similarity\"],\n",
-"    task_config={\n",
+"    benchmark_config={\n",
 "        \"type\": \"benchmark\",\n",
 "        \"eval_candidate\": {\n",
 "            \"type\": \"model\",\n",

File diff suppressed because it is too large


@@ -781,7 +781,7 @@
 "    benchmark_id=\"meta-reference::mmmu\",\n",
 "    input_rows=eval_rows,\n",
 "    scoring_functions=[\"basic::regex_parser_multiple_choice_answer\"],\n",
-"    task_config={\n",
+"    benchmark_config={\n",
 "        \"type\": \"benchmark\",\n",
 "        \"eval_candidate\": {\n",
 "            \"type\": \"model\",\n",
@@ -826,10 +826,9 @@
 "_ = client.datasets.register(\n",
 "    dataset_id=simpleqa_dataset_id,\n",
 "    provider_id=\"huggingface\",\n",
-"    url={\"uri\": \"https://huggingface.co/datasets/llamastack/evals\"},\n",
+"    url={\"uri\": \"https://huggingface.co/datasets/llamastack/simpleqa\"},\n",
 "    metadata={\n",
-"        \"path\": \"llamastack/evals\",\n",
-"        \"name\": \"evals__simpleqa\",\n",
+"        \"path\": \"llamastack/simpleqa\",\n",
 "        \"split\": \"train\",\n",
 "    },\n",
 "    dataset_schema={\n",
@@ -960,7 +959,7 @@
 "    benchmark_id=\"meta-reference::simpleqa\",\n",
 "    input_rows=eval_rows.rows,\n",
 "    scoring_functions=[\"llm-as-judge::405b-simpleqa\"],\n",
-"    task_config={\n",
+"    benchmark_config={\n",
 "        \"type\": \"benchmark\",\n",
 "        \"eval_candidate\": {\n",
 "            \"type\": \"model\",\n",
@@ -1109,7 +1108,7 @@
 "    benchmark_id=\"meta-reference::simpleqa\",\n",
 "    input_rows=eval_rows.rows,\n",
 "    scoring_functions=[\"llm-as-judge::405b-simpleqa\"],\n",
-"    task_config={\n",
+"    benchmark_config={\n",
 "        \"type\": \"benchmark\",\n",
 "        \"eval_candidate\": {\n",
 "            \"type\": \"agent\",\n",


@@ -3,7 +3,7 @@ The RFC Specification (OpenAPI format) is generated from the set of API endpoint
 Please install the following packages before running the script:

 ```
-pip install fire PyYAML llama-models
+pip install fire PyYAML
 ```

 Then simply run `sh run_openapi_generator.sh`


@@ -55,6 +55,7 @@ def main(output_dir: str):
                 a set of endpoints and their corresponding interfaces that are tailored to
                 best leverage Llama Models.""",
             ),
+            include_standard_error_responses=True,
         ),
     )


@@ -10,6 +10,7 @@ import typing
 from dataclasses import make_dataclass
 from typing import Any, Dict, Set, Union

+from llama_stack.apis.datatypes import Error
 from llama_stack.strong_typing.core import JsonType
 from llama_stack.strong_typing.docstring import Docstring, parse_type
 from llama_stack.strong_typing.inspection import (
@@ -434,6 +435,75 @@ class Generator:
         )
         self.schema_builder = SchemaBuilder(schema_generator)
         self.responses = {}
+
+        # Create standard error responses
+        self._create_standard_error_responses()
+
+    def _create_standard_error_responses(self) -> None:
+        """
+        Creates standard error responses that can be reused across operations.
+        These will be added to the components.responses section of the OpenAPI document.
+        """
+        # Get the Error schema
+        error_schema = self.schema_builder.classdef_to_ref(Error)
+
+        # Create standard error responses
+        self.responses["BadRequest400"] = Response(
+            description="The request was invalid or malformed",
+            content={
+                "application/json": MediaType(
+                    schema=error_schema,
+                    example={
+                        "status": 400,
+                        "title": "Bad Request",
+                        "detail": "The request was invalid or malformed",
+                    }
+                )
+            }
+        )
+
+        self.responses["TooManyRequests429"] = Response(
+            description="The client has sent too many requests in a given amount of time",
+            content={
+                "application/json": MediaType(
+                    schema=error_schema,
+                    example={
+                        "status": 429,
+                        "title": "Too Many Requests",
+                        "detail": "You have exceeded the rate limit. Please try again later.",
+                    }
+                )
+            }
+        )
+
+        self.responses["InternalServerError500"] = Response(
+            description="The server encountered an unexpected error",
+            content={
+                "application/json": MediaType(
+                    schema=error_schema,
+                    example={
+                        "status": 500,
+                        "title": "Internal Server Error",
+                        "detail": "An unexpected error occurred. Our team has been notified.",
+                    }
+                )
+            }
+        )
+
+        # Add a default error response for any unhandled error cases
+        self.responses["DefaultError"] = Response(
+            description="An unexpected error occurred",
+            content={
+                "application/json": MediaType(
+                    schema=error_schema,
+                    example={
+                        "status": 0,
+                        "title": "Error",
+                        "detail": "An unexpected error occurred",
+                    }
+                )
+            }
+        )

     def _build_type_tag(self, ref: str, schema: Schema) -> Tag:
         # Don't include schema definition in the tag description because for one,
@ -649,6 +719,18 @@ class Generator:
responses.update(response_builder.build_response(response_options)) responses.update(response_builder.build_response(response_options))
assert len(responses.keys()) > 0, f"No responses found for {op.name}" assert len(responses.keys()) > 0, f"No responses found for {op.name}"
# Add standard error response references
if self.options.include_standard_error_responses:
if "400" not in responses:
responses["400"] = ResponseRef("BadRequest400")
if "429" not in responses:
responses["429"] = ResponseRef("TooManyRequests429")
if "500" not in responses:
responses["500"] = ResponseRef("InternalServerError500")
if "default" not in responses:
responses["default"] = ResponseRef("DefaultError")
if op.event_type is not None: if op.event_type is not None:
builder = ContentBuilder(self.schema_builder) builder = ContentBuilder(self.schema_builder)
callbacks = { callbacks = {

View file

@ -35,6 +35,7 @@ class Options:
:param error_wrapper: True if errors are encapsulated in an error object wrapper. :param error_wrapper: True if errors are encapsulated in an error object wrapper.
:param property_description_fun: Custom transformation function to apply to class property documentation strings. :param property_description_fun: Custom transformation function to apply to class property documentation strings.
:param captions: User-defined captions for sections such as "Operations" or "Types", and (if applicable) groups of extra types. :param captions: User-defined captions for sections such as "Operations" or "Types", and (if applicable) groups of extra types.
:param include_standard_error_responses: Whether to include standard error responses (400, 429, 500, 503) in all operations.
""" """
server: Server server: Server
@ -52,6 +53,7 @@ class Options:
error_wrapper: bool = False error_wrapper: bool = False
property_description_fun: Optional[Callable[[type, str, str], str]] = None property_description_fun: Optional[Callable[[type, str, str], str]] = None
captions: Optional[Dict[str, str]] = None captions: Optional[Dict[str, str]] = None
include_standard_error_responses: bool = True
default_captions: ClassVar[Dict[str, str]] = { default_captions: ClassVar[Dict[str, str]] = {
"Operations": "Operations", "Operations": "Operations",

View file

@ -28,6 +28,5 @@ if [ ${#missing_packages[@]} -ne 0 ]; then
fi fi
stack_dir=$(dirname $(dirname $THIS_DIR)) stack_dir=$(dirname $(dirname $THIS_DIR))
models_dir=$(dirname $stack_dir)/llama-models PYTHONPATH=$PYTHONPATH:$stack_dir \
PYTHONPATH=$PYTHONPATH:$stack_dir:$models_dir \
python -m docs.openapi_generator.generate $(dirname $THIS_DIR)/_static python -m docs.openapi_generator.generate $(dirname $THIS_DIR)/_static

View file

@ -11,3 +11,4 @@ sphinxcontrib-openapi
sphinxcontrib-redoc sphinxcontrib-redoc
sphinxcontrib-mermaid sphinxcontrib-mermaid
sphinxcontrib-video sphinxcontrib-video
tomli

View file

@ -0,0 +1,89 @@
# Llama Stack Agent Framework
The Llama Stack agent framework is built on a modular architecture that allows for flexible and powerful AI applications. This document explains the key components and how they work together.
## Core Concepts
### 1. Agent Configuration
Agents are configured using the `AgentConfig` class, which includes:
- **Model**: The underlying LLM to power the agent
- **Instructions**: System prompt that defines the agent's behavior
- **Tools**: Capabilities the agent can use to interact with external systems
- **Safety Shields**: Guardrails to ensure responsible AI behavior
```python
from llama_stack_client.lib.agents.agent import Agent
# Create the agent
agent = Agent(
llama_stack_client,
model="meta-llama/Llama-3-70b-chat",
instructions="You are a helpful assistant that can use tools to answer questions.",
tools=["builtin::code_interpreter", "builtin::rag/knowledge_search"],
)
```
### 2. Sessions
Agents maintain state through sessions, each of which represents a conversation thread:
```python
# Create a session
session_id = agent.create_session(session_name="My conversation")
```
### 3. Turns
Each interaction with an agent is called a "turn" and consists of:
- **Input Messages**: What the user sends to the agent
- **Steps**: The agent's internal processing (inference, tool execution, etc.)
- **Output Message**: The agent's response
```python
from llama_stack_client.lib.agents.event_logger import EventLogger
# Create a turn with streaming response
turn_response = agent.create_turn(
session_id=session_id,
messages=[{"role": "user", "content": "Tell me about Llama models"}],
)
for log in EventLogger().log(turn_response):
log.print()
```
### Non-Streaming
```python
from rich.pretty import pprint
# Non-streaming API
response = agent.create_turn(
session_id=session_id,
messages=[{"role": "user", "content": "Tell me about Llama models"}],
stream=False,
)
print("Inputs:")
pprint(response.input_messages)
print("Output:")
pprint(response.output_message.content)
print("Steps:")
pprint(response.steps)
```
### 4. Steps
Each turn consists of multiple steps that represent the agent's thought process:
- **Inference Steps**: The agent generating text responses
- **Tool Execution Steps**: The agent using tools to gather information
- **Shield Call Steps**: Safety checks being performed
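When using the non-streaming API, you can inspect these steps directly on the returned turn object. The sketch below assumes the `response` object from the non-streaming example above and relies only on the `step_type`, `tool_calls`, and `tool_name` fields used elsewhere in this guide:
```python
# Minimal sketch: walk the recorded steps of a non-streaming turn.
# Assumes `response` is the turn object returned with stream=False above.
for step in response.steps:
    print(f"Step type: {step.step_type}")
    if step.step_type == "tool_execution":
        # Tool execution steps record which tools the agent invoked.
        for tool_call in step.tool_calls:
            print(f"  invoked tool: {tool_call.tool_name}")
```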
## Agent Execution Loop
Refer to the [Agent Execution Loop](agent_execution_loop) for more details on what happens within an agent turn.

View file

@ -13,7 +13,7 @@ Each agent turn follows these key steps:
3. **Inference Loop**: The agent enters its main execution loop: 3. **Inference Loop**: The agent enters its main execution loop:
- The LLM receives a user prompt (with previous tool outputs) - The LLM receives a user prompt (with previous tool outputs)
- The LLM generates a response, potentially with tool calls - The LLM generates a response, potentially with [tool calls](tools)
- If tool calls are present: - If tool calls are present:
- Tool inputs are safety-checked - Tool inputs are safety-checked
- Tools are executed (e.g., web search, code execution) - Tools are executed (e.g., web search, code execution)
@ -67,20 +67,28 @@ sequenceDiagram
Each step in this process can be monitored and controlled through configurations. Here's an example that demonstrates monitoring the agent's execution: Each step in this process can be monitored and controlled through configurations. Here's an example that demonstrates monitoring the agent's execution:
```python ```python
from llama_stack_client import LlamaStackClient
from llama_stack_client.lib.agents.agent import Agent
from llama_stack_client.lib.agents.event_logger import EventLogger from llama_stack_client.lib.agents.event_logger import EventLogger
from rich.pretty import pprint
agent_config = AgentConfig( # Replace host and port
client = LlamaStackClient(base_url=f"http://{HOST}:{PORT}")
agent = Agent(
client,
# Check with `llama-stack-client models list`
model="Llama3.2-3B-Instruct", model="Llama3.2-3B-Instruct",
instructions="You are a helpful assistant", instructions="You are a helpful assistant",
# Enable both RAG and tool usage # Enable both RAG and tool usage
toolgroups=[ tools=[
{ {
"name": "builtin::rag/knowledge_search", "name": "builtin::rag/knowledge_search",
"args": {"vector_db_ids": ["my_docs"]}, "args": {"vector_db_ids": ["my_docs"]},
}, },
"builtin::code_interpreter", "builtin::code_interpreter",
], ],
# Configure safety # Configure safety (optional)
input_shields=["llama_guard"], input_shields=["llama_guard"],
output_shields=["llama_guard"], output_shields=["llama_guard"],
# Control the inference loop # Control the inference loop
@ -90,14 +98,12 @@ agent_config = AgentConfig(
"max_tokens": 2048, "max_tokens": 2048,
}, },
) )
agent = Agent(client, agent_config)
session_id = agent.create_session("monitored_session") session_id = agent.create_session("monitored_session")
# Stream the agent's execution steps # Stream the agent's execution steps
response = agent.create_turn( response = agent.create_turn(
messages=[{"role": "user", "content": "Analyze this code and run it"}], messages=[{"role": "user", "content": "Analyze this code and run it"}],
attachments=[ documents=[
{ {
"content": "https://raw.githubusercontent.com/example/code.py", "content": "https://raw.githubusercontent.com/example/code.py",
"mime_type": "text/plain", "mime_type": "text/plain",
@ -108,14 +114,21 @@ response = agent.create_turn(
# Monitor each step of execution # Monitor each step of execution
for log in EventLogger().log(response): for log in EventLogger().log(response):
if log.event.step_type == "memory_retrieval": log.print()
print("Retrieved context:", log.event.retrieved_context)
elif log.event.step_type == "inference": # Using non-streaming API, the response contains input, steps, and output.
print("LLM output:", log.event.model_response) response = agent.create_turn(
elif log.event.step_type == "tool_execution": messages=[{"role": "user", "content": "Analyze this code and run it"}],
print("Tool call:", log.event.tool_call) documents=[
print("Tool response:", log.event.tool_response) {
elif log.event.step_type == "shield_call": "content": "https://raw.githubusercontent.com/example/code.py",
if log.event.violation: "mime_type": "text/plain",
print("Safety violation:", log.event.violation) }
],
session_id=session_id,
)
pprint(f"Input: {response.input_messages}")
pprint(f"Output: {response.output_message.content}")
pprint(f"Steps: {response.steps}")
``` ```

View file

@ -1,170 +1,124 @@
# Evals # Evaluations
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/10CHyykee9j2OigaIcRv47BKG9mrNm0tJ?usp=sharing) Llama Stack provides a set of APIs for running evaluations of LLM applications.
- `/datasetio` + `/datasets` API
- `/scoring` + `/scoring_functions` API
- `/eval` + `/benchmarks` API
Llama Stack provides the building blocks needed to run benchmark and application evaluations. This guide will walk you through how to use these components to run open benchmark evaluations. Visit our [Evaluation Concepts](../concepts/evaluation_concepts.md) guide for more details on how evaluations work in Llama Stack, and our [Evaluation Reference](../references/evals_reference/index.md) guide for a comprehensive reference on the APIs.
### 1. Open Benchmark Model Evaluation
This first example walks you through how to evaluate a model candidate served by Llama Stack on open benchmarks. We will use the following benchmark: This guide walks you through the process of evaluating an LLM application built using Llama Stack. The [Evaluation Reference](../references/evals_reference/index.md) guide goes over the set of APIs and the developer experience flow of using Llama Stack to run evaluations for benchmark and application use cases. Check out our Colab notebook on working examples with evaluations [here](https://colab.research.google.com/drive/10CHyykee9j2OigaIcRv47BKG9mrNm0tJ?usp=sharing).
- [MMMU](https://arxiv.org/abs/2311.16502) (A Massive Multi-discipline Multimodal Understanding and Reasoning Benchmark for Expert AGI): Benchmark designed to evaluate multimodal models.
- [SimpleQA](https://openai.com/index/introducing-simpleqa/): Benchmark designed to assess models' ability to answer short, fact-seeking questions.
#### 1.1 Running MMMU
- We will use a pre-processed MMMU dataset from [llamastack/mmmu](https://huggingface.co/datasets/llamastack/mmmu). The preprocessing code is shown in this [Github Gist](https://gist.github.com/yanxi0830/118e9c560227d27132a7fd10e2c92840). The dataset is obtained by transforming the original [MMMU/MMMU](https://huggingface.co/datasets/MMMU/MMMU) dataset into the correct format accepted by the `inference/chat-completion` API.
## Application Evaluation
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/meta-llama/llama-stack/blob/main/docs/getting_started.ipynb)
Llama Stack offers a library of scoring functions and the `/scoring` API, allowing you to run evaluations on your pre-annotated AI application datasets.
In this example, we will show you how to:
1. Build an Agent with Llama Stack
2. Query the agent's sessions, turns, and steps
3. Evaluate the results.
##### Building a Search Agent
```python ```python
import datasets from llama_stack_client.lib.agents.agent import Agent
from llama_stack_client.lib.agents.event_logger import EventLogger
ds = datasets.load_dataset(path="llamastack/mmmu", name="Agriculture", split="dev") agent = Agent(
ds = ds.select_columns(["chat_completion_input", "input_query", "expected_answer"]) client,
eval_rows = ds.to_pandas().to_dict(orient="records") model="meta-llama/Llama-3.3-70B-Instruct",
``` instructions="You are a helpful assistant. Use search tool to answer the questions. ",
tools=["builtin::websearch"],
- Next, we will run evaluation on an model candidate, we will need to:
- Define a system prompt
- Define an EvalCandidate
- Run evaluate on the dataset
```python
SYSTEM_PROMPT_TEMPLATE = """
You are an expert in Agriculture whose job is to answer questions from the user using images.
First, reason about the correct answer.
Then write the answer in the following format where X is exactly one of A,B,C,D:
Answer: X
Make sure X is one of A,B,C,D.
If you are uncertain of the correct answer, guess the most likely one.
"""
system_message = {
"role": "system",
"content": SYSTEM_PROMPT_TEMPLATE,
}
client.benchmarks.register(
benchmark_id="meta-reference::mmmu",
dataset_id=f"mmmu-{subset}-{split}",
scoring_functions=["basic::regex_parser_multiple_choice_answer"],
) )
user_prompts = [
"Which teams played in the NBA western conference finals of 2024. Search the web for the answer.",
"In which episode and season of South Park does Bill Cosby (BSM-471) first appear? Give me the number and title. Search the web for the answer.",
"What is the British-American kickboxer Andrew Tate's kickboxing name? Search the web for the answer.",
]
response = client.eval.evaluate_rows( session_id = agent.create_session("test-session")
benchmark_id="meta-reference::mmmu",
input_rows=eval_rows,
scoring_functions=["basic::regex_parser_multiple_choice_answer"],
task_config={
"type": "benchmark",
"eval_candidate": {
"type": "model",
"model": "meta-llama/Llama-3.2-90B-Vision-Instruct",
"sampling_params": {
"strategy": {
"type": "greedy",
},
"max_tokens": 4096,
"repeat_penalty": 1.0,
},
"system_message": system_message,
},
},
)
```
#### 1.2. Running SimpleQA for prompt in user_prompts:
- We will use a pre-processed SimpleQA dataset from [llamastack/evals](https://huggingface.co/datasets/llamastack/evals/viewer/evals__simpleqa) which is obtained by transforming the input query into the correct format accepted by the `inference/chat-completion` API. response = agent.create_turn(
- Since we will be using this same dataset in our next example for Agentic evaluation, we will register it using the `/datasets` API, and interact with it through `/datasetio` API. messages=[
{
"role": "user",
"content": prompt,
}
],
session_id=session_id,
)
```python for log in EventLogger().log(response):
simpleqa_dataset_id = "huggingface::simpleqa" log.print()
_ = client.datasets.register(
dataset_id=simpleqa_dataset_id,
provider_id="huggingface",
url={"uri": "https://huggingface.co/datasets/llamastack/evals"},
metadata={
"path": "llamastack/evals",
"name": "evals__simpleqa",
"split": "train",
},
dataset_schema={
"input_query": {"type": "string"},
"expected_answer": {"type": "string"},
"chat_completion_input": {"type": "chat_completion_input"},
},
)
eval_rows = client.datasetio.get_rows_paginated(
dataset_id=simpleqa_dataset_id,
rows_in_page=5,
)
```
```python
client.benchmarks.register(
benchmark_id="meta-reference::simpleqa",
dataset_id=simpleqa_dataset_id,
scoring_functions=["llm-as-judge::405b-simpleqa"],
)
response = client.eval.evaluate_rows(
benchmark_id="meta-reference::simpleqa",
input_rows=eval_rows.rows,
scoring_functions=["llm-as-judge::405b-simpleqa"],
task_config={
"type": "benchmark",
"eval_candidate": {
"type": "model",
"model": "meta-llama/Llama-3.2-90B-Vision-Instruct",
"sampling_params": {
"strategy": {
"type": "greedy",
},
"max_tokens": 4096,
"repeat_penalty": 1.0,
},
},
},
)
``` ```
### 2. Agentic Evaluation ##### Query Agent Execution Steps
- In this example, we will demonstrate how to evaluate an agent candidate served by Llama Stack via the `/agent` API.
- We will continue to use the SimpleQA dataset we used in the previous example. Now, let's look deeper into the agent's execution steps and see how well our agent performs.
- Instead of running evaluation on the model, we will run the evaluation on a Search Agent with access to a search tool. We will define our agent evaluation candidate through `AgentConfig`. ```python
# query the agents session
from rich.pretty import pprint
session_response = client.agents.session.retrieve(
session_id=session_id,
agent_id=agent.agent_id,
)
pprint(session_response)
```
As a sanity check, we will first check whether each user prompt is followed by a tool call to `brave_search`.
```python
num_tool_call = 0
for turn in session_response.turns:
for step in turn.steps:
if (
step.step_type == "tool_execution"
and step.tool_calls[0].tool_name == "brave_search"
):
num_tool_call += 1
print(
f"{num_tool_call}/{len(session_response.turns)} user prompts are followed by a tool call to `brave_search`"
)
```
##### Evaluate Agent Responses
Now, we want to evaluate the agent's responses to the user prompts.
1. First, we will process the agent's execution history into a list of rows that can be used for evaluation.
2. Next, we will label the rows with the expected answer.
3. Finally, we will use the `/scoring` API to score the agent's responses.
```python ```python
agent_config = { eval_rows = []
"model": "meta-llama/Llama-3.1-405B-Instruct",
"instructions": "You are a helpful assistant", expected_answers = [
"sampling_params": { "Dallas Mavericks and the Minnesota Timberwolves",
"strategy": { "Season 4, Episode 12",
"type": "greedy", "King Cobra",
}, ]
},
"tools": [ for i, turn in enumerate(session_response.turns):
eval_rows.append(
{ {
"type": "brave_search", "input_query": turn.input_messages[0].content,
"engine": "tavily", "generated_answer": turn.output_message.content,
"api_key": userdata.get("TAVILY_SEARCH_API_KEY"), "expected_answer": expected_answers[i],
} }
], )
"tool_choice": "auto",
"tool_prompt_format": "json",
"input_shields": [],
"output_shields": [],
"enable_session_persistence": False,
}
response = client.eval.evaluate_rows( pprint(eval_rows)
benchmark_id="meta-reference::simpleqa",
input_rows=eval_rows.rows, scoring_params = {
scoring_functions=["llm-as-judge::405b-simpleqa"], "basic::subset_of": None,
task_config={ }
"type": "benchmark", scoring_response = client.scoring.score(
"eval_candidate": { input_rows=eval_rows, scoring_functions=scoring_params
"type": "agent",
"config": agent_config,
},
},
) )
pprint(scoring_response)
``` ```

View file

@ -1,30 +0,0 @@
## Testing & Evaluation
Llama Stack provides built-in tools for evaluating your applications:
1. **Benchmarking**: Test against standard datasets
2. **Application Evaluation**: Score your application's outputs
3. **Custom Metrics**: Define your own evaluation criteria
Here's how to set up basic evaluation:
```python
# Create an evaluation task
response = client.benchmarks.register(
benchmark_id="my_eval",
dataset_id="my_dataset",
scoring_functions=["accuracy", "relevance"],
)
# Run evaluation
job = client.eval.run_eval(
benchmark_id="my_eval",
task_config={
"type": "app",
"eval_candidate": {"type": "agent", "config": agent_config},
},
)
# Get results
result = client.eval.job_result(benchmark_id="my_eval", job_id=job.job_id)
```

View file

@ -8,22 +8,24 @@ The best way to get started is to look at this notebook which walks through the
Here are some key topics that will help you build effective agents: Here are some key topics that will help you build effective agents:
- **[Agent Execution Loop](agent_execution_loop)** - **[Agent](agent)**: Understand the components and design patterns of the Llama Stack agent framework.
- **[RAG](rag)** - **[Agent Execution Loop](agent_execution_loop)**: Understand how agents process information, make decisions, and execute actions in a continuous loop.
- **[Safety](safety)** - **[RAG (Retrieval-Augmented Generation)](rag)**: Learn how to enhance your agents with external knowledge through retrieval mechanisms.
- **[Tools](tools)** - **[Tools](tools)**: Extend your agents' capabilities by integrating with external tools and APIs.
- **[Telemetry](telemetry)** - **[Evals](evals)**: Evaluate your agents' effectiveness and identify areas for improvement.
- **[Evals](evals)** - **[Telemetry](telemetry)**: Monitor and analyze your agents' performance and behavior.
- **[Safety](safety)**: Implement guardrails and safety measures to ensure responsible AI behavior.
```{toctree} ```{toctree}
:hidden: :hidden:
:maxdepth: 1 :maxdepth: 1
agent
agent_execution_loop agent_execution_loop
rag rag
safety
tools tools
telemetry telemetry
evals evals
advanced_agent_patterns
safety
``` ```

View file

@ -1,8 +1,8 @@
## Using "Memory" or Retrieval Augmented Generation (RAG) ## Using Retrieval Augmented Generation (RAG)
Memory enables your applications to reference and recall information from previous interactions or external documents. RAG enables your applications to reference and recall information from previous interactions or external documents.
Llama Stack organizes the memory APIs into three layers: Llama Stack organizes the APIs that enable RAG into three layers:
- the lowermost APIs deal with raw storage and retrieval. These include Vector IO, KeyValue IO (coming soon) and Relational IO (also coming soon.) - the lowermost APIs deal with raw storage and retrieval. These include Vector IO, KeyValue IO (coming soon) and Relational IO (also coming soon.)
- next is the "Rag Tool", a first-class tool as part of the Tools API that allows you to ingest documents (from URLs, files, etc) with various chunking strategies and query them smartly. - next is the "Rag Tool", a first-class tool as part of the Tools API that allows you to ingest documents (from URLs, files, etc) with various chunking strategies and query them smartly.
- finally, it all comes together with the top-level "Agents" API that allows you to create agents that can use the tools to answer questions, perform tasks, and more. - finally, it all comes together with the top-level "Agents" API that allows you to create agents that can use the tools to answer questions, perform tasks, and more.
@ -20,6 +20,11 @@ We may add more storage types like Graph IO in the future.
Here's how to set up a vector database for RAG: Here's how to set up a vector database for RAG:
```python ```python
# Create HTTP client
import os

from llama_stack_client import LlamaStackClient

client = LlamaStackClient(base_url=f"http://localhost:{os.environ['LLAMA_STACK_PORT']}")
# Register a vector db # Register a vector db
vector_db_id = "my_documents" vector_db_id = "my_documents"
response = client.vector_dbs.register( response = client.vector_dbs.register(
@ -81,15 +86,14 @@ results = client.tool_runtime.rag_tool.query(
One of the most powerful patterns is combining agents with RAG capabilities. Here's a complete example: One of the most powerful patterns is combining agents with RAG capabilities. Here's a complete example:
```python ```python
from llama_stack_client.types.agent_create_params import AgentConfig
from llama_stack_client.lib.agents.agent import Agent from llama_stack_client.lib.agents.agent import Agent
# Configure agent with memory # Create agent with memory
agent_config = AgentConfig( agent = Agent(
model="meta-llama/Llama-3.2-3B-Instruct", client,
model="meta-llama/Llama-3.3-70B-Instruct",
instructions="You are a helpful assistant", instructions="You are a helpful assistant",
enable_session_persistence=False, tools=[
toolgroups=[
{ {
"name": "builtin::rag/knowledge_search", "name": "builtin::rag/knowledge_search",
"args": { "args": {
@ -98,10 +102,21 @@ agent_config = AgentConfig(
} }
], ],
) )
agent = Agent(client, agent_config)
session_id = agent.create_session("rag_session") session_id = agent.create_session("rag_session")
# Ask questions about documents in the vector db, and the agent will query the db to answer the question.
response = agent.create_turn(
messages=[{"role": "user", "content": "How to optimize memory in PyTorch?"}],
session_id=session_id,
)
```
> **NOTE:** the `instructions` field in the `AgentConfig` can be used to guide the agent's behavior. It is important to experiment with different instructions to see what works best for your use case.
You can also pass documents along with the user's message and ask questions about them.
```python
# Initial document ingestion # Initial document ingestion
response = agent.create_turn( response = agent.create_turn(
messages=[ messages=[
@ -109,7 +124,7 @@ response = agent.create_turn(
], ],
documents=[ documents=[
{ {
"content": "https://raw.githubusercontent.com/example/doc.rst", "content": "https://raw.githubusercontent.com/pytorch/torchtune/main/docs/source/tutorials/memory_optimizations.rst",
"mime_type": "text/plain", "mime_type": "text/plain",
} }
], ],
@ -123,6 +138,14 @@ response = agent.create_turn(
) )
``` ```
You can print the response as shown below.
```python
from llama_stack_client.lib.agents.event_logger import EventLogger
for log in EventLogger().log(response):
log.print()
```
### Unregistering Vector DBs ### Unregistering Vector DBs
If you need to clean up and unregister vector databases, you can do so as follows: If you need to clean up and unregister vector databases, you can do so as follows:

View file

@ -5,7 +5,7 @@ An example of this would be a "db_access" tool group that contains tools for int
Tools are treated as any other resource in llama stack like models. You can register them, have providers for them etc. Tools are treated as any other resource in llama stack like models. You can register them, have providers for them etc.
When instatiating an agent, you can provide it a list of tool groups that it has access to. Agent gets the corresponding tool definitions for the specified tool groups and passes them along to the model. When instantiating an agent, you can provide it a list of tool groups that it has access to. Agent gets the corresponding tool definitions for the specified tool groups and passes them along to the model.
Refer to the [Building AI Applications](https://github.com/meta-llama/llama-stack/blob/main/docs/getting_started.ipynb) notebook for more examples on how to use tools. Refer to the [Building AI Applications](https://github.com/meta-llama/llama-stack/blob/main/docs/getting_started.ipynb) notebook for more examples on how to use tools.
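As a quick illustration, registering a tool group and handing it to an agent looks roughly like the sketch below. This is a hedged example assuming a configured `client`; the toolgroup ID, provider ID, and model are illustrative and should be adapted to the providers available in your distribution.
```python
from llama_stack_client.lib.agents.agent import Agent

# Register a tool group backed by a provider configured in your distribution
# (IDs below are illustrative).
client.toolgroups.register(
    toolgroup_id="builtin::websearch",
    provider_id="tavily-search",
)

# Give the agent access to the registered tool group.
agent = Agent(
    client,
    model="meta-llama/Llama-3.2-3B-Instruct",
    instructions="You are a helpful assistant that can search the web.",
    tools=["builtin::websearch"],
)
```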
@ -60,7 +60,7 @@ Features:
- Disabled dangerous system operations - Disabled dangerous system operations
- Configurable execution timeouts - Configurable execution timeouts
> ⚠️ Important: The code interpreter tool can operate in a controlled enviroment locally or on Podman containers. To ensure proper functionality in containerised environments: > ⚠️ Important: The code interpreter tool can operate in a controlled environment locally or on Podman containers. To ensure proper functionality in containerized environments:
> - The container requires privileged access (e.g., --privileged). > - The container requires privileged access (e.g., --privileged).
> - Users without sufficient permissions may encounter permission errors. (`bwrap: Can't mount devpts on /newroot/dev/pts: Permission denied`) > - Users without sufficient permissions may encounter permission errors. (`bwrap: Can't mount devpts on /newroot/dev/pts: Permission denied`)
> - 🔒 Security Warning: Privileged mode grants elevated access and bypasses security restrictions. Use only in local, isolated, or controlled environments. > - 🔒 Security Warning: Privileged mode grants elevated access and bypasses security restrictions. Use only in local, isolated, or controlled environments.
@ -83,15 +83,15 @@ result = client.tool_runtime.invoke_tool(
) )
``` ```
#### Memory #### RAG
The Memory tool enables retrieval of context from various types of memory banks (vector, key-value, keyword, and graph). The RAG tool enables retrieval of context from various types of memory banks (vector, key-value, keyword, and graph).
```python ```python
# Register Memory tool group # Register Memory tool group
client.toolgroups.register( client.toolgroups.register(
toolgroup_id="builtin::memory", toolgroup_id="builtin::rag",
provider_id="memory", provider_id="faiss",
args={"max_chunks": 5, "max_tokens_in_context": 4096}, args={"max_chunks": 5, "max_tokens_in_context": 4096},
) )
``` ```
@ -102,7 +102,7 @@ Features:
- Context retrieval with token limits - Context retrieval with token limits
> **Note:** By default, llama stack run.yaml defines toolgroups for web search, code interpreter and memory, that are provided by tavily-search, code-interpreter and memory providers. > **Note:** By default, llama stack run.yaml defines toolgroups for web search, code interpreter and rag, that are provided by tavily-search, code-interpreter and rag providers.
## Model Context Protocol (MCP) Tools ## Model Context Protocol (MCP) Tools
@ -125,50 +125,35 @@ MCP tools require:
- Tools are discovered dynamically from the endpoint - Tools are discovered dynamically from the endpoint
## Tools provided by the client ## Adding Custom Tools
These tools are registered along with the agent config and are specific to the agent for which they are registered. The main difference between these tools and the tools provided by the built-in providers is that the execution of these tools is handled by the client and the agent transfers the tool call to the client and waits for the result from the client. When you want to use tools other than the built-in tools, you can implement a python function and decorate it with `@client_tool`.
To define a custom tool, you need to use the `@client_tool` decorator.
```python
from llama_stack_client.lib.agents.client_tool import client_tool
# Example tool definition
@client_tool
def my_tool(input: int) -> int:
"""
Runs my awesome tool.
:param input: some int parameter
"""
return input * 2
```
> **NOTE:** We employ python docstrings to describe the tool and the parameters. It is important to document the tool and the parameters so that the model can use the tool correctly. It is recommended to experiment with different docstrings to see how they affect the model's behavior.
Once defined, simply pass the tool to the agent config. `Agent` will take care of the rest (calling the model with the tool definition, executing the tool, and returning the result to the model for the next iteration).
```python ```python
# Example agent config with client provided tools # Example agent config with client provided tools
config = AgentConfig( agent = Agent(client, ..., tools=[my_tool])
toolgroups=[
"builtin::websearch",
],
client_tools=[ToolDef(name="client_tool", description="Client provided tool")],
)
``` ```
Refer to [llama-stack-apps](https://github.com/meta-llama/llama-stack-apps/blob/main/examples/agents/e2e_loop_with_client_tools.py) for an example of how to use client provided tools. Refer to [llama-stack-apps](https://github.com/meta-llama/llama-stack-apps/blob/main/examples/agents/e2e_loop_with_client_tools.py) for an example of how to use client provided tools.
## Tool Structure
Each tool has the following components:
- `name`: Unique identifier for the tool
- `description`: Human-readable description of the tool's functionality
- `parameters`: List of parameters the tool accepts
- `name`: Parameter name
- `parameter_type`: Data type (string, number, etc.)
- `description`: Parameter description
- `required`: Whether the parameter is required (default: true)
- `default`: Default value if any
Example tool definition:
```python
{
"name": "web_search",
"description": "Search the web for information",
"parameters": [
{
"name": "query",
"parameter_type": "string",
"description": "The query to search for",
"required": True,
}
],
}
```
## Tool Invocation ## Tool Invocation
@ -201,10 +186,10 @@ group_tools = client.tools.list_tools(toolgroup_id="search_tools")
```python ```python
from llama_stack_client.lib.agents.agent import Agent from llama_stack_client.lib.agents.agent import Agent
from llama_stack_client.types.agent_create_params import AgentConfig
# Configure the AI agent with necessary parameters # Instantiate the AI agent with the given configuration
agent_config = AgentConfig( agent = Agent(
client,
name="code-interpreter", name="code-interpreter",
description="A code interpreter agent for executing Python code snippets", description="A code interpreter agent for executing Python code snippets",
instructions=""" instructions="""
@ -212,14 +197,10 @@ agent_config = AgentConfig(
Always show the generated code, never generate your own code, and never anticipate results. Always show the generated code, never generate your own code, and never anticipate results.
""", """,
model="meta-llama/Llama-3.2-3B-Instruct", model="meta-llama/Llama-3.2-3B-Instruct",
toolgroups=["builtin::code_interpreter"], tools=["builtin::code_interpreter"],
max_infer_iters=5, max_infer_iters=5,
enable_session_persistence=False,
) )
# Instantiate the AI agent with the given configuration
agent = Agent(client, agent_config)
# Start a session # Start a session
session_id = agent.create_session("tool_session") session_id = agent.create_session("tool_session")

View file

@ -24,17 +24,8 @@ The Evaluation APIs are associated with a set of Resources as shown in the follo
- Associated with `Benchmark` resource. - Associated with `Benchmark` resource.
Use the following decision tree to decide how to use LlamaStack Evaluation flow.
![Eval Flow](../references/evals_reference/resources/eval-flow.png)
```{admonition} Note on Benchmark v.s. Application Evaluation
:class: tip
- **Benchmark Evaluation** is a well-defined eval-task consisting of `dataset` and `scoring_function`. The generation (inference or agent) will be done as part of evaluation.
- **Application Evaluation** assumes users already have app inputs & generated outputs. Evaluation will purely focus on scoring the generated outputs via scoring functions (e.g. LLM-as-judge).
```
## What's Next? ## What's Next?
- Check out our Colab notebook on working examples with evaluations [here](https://colab.research.google.com/drive/10CHyykee9j2OigaIcRv47BKG9mrNm0tJ?usp=sharing). - Check out our Colab notebook on working examples with running benchmark evaluations [here](https://colab.research.google.com/github/meta-llama/llama-stack/blob/main/docs/notebooks/Llama_Stack_Benchmark_Evals.ipynb#scrollTo=mxLCsP4MvFqP).
- Check out our [Building Applications - Evaluation](../building_applications/evals.md) guide for more details on how to use the Evaluation APIs to evaluate your applications.
- Check out our [Evaluation Reference](../references/evals_reference/index.md) for more details on the APIs. - Check out our [Evaluation Reference](../references/evals_reference/index.md) for more details on the APIs.

View file

@ -1,5 +1,13 @@
# Core Concepts # Core Concepts
```{toctree}
:maxdepth: 1
:hidden:
evaluation_concepts
```
Given Llama Stack's service-oriented philosophy, a few concepts and workflows arise which may not feel completely natural in the LLM landscape, especially if you are coming with a background in other frameworks. Given Llama Stack's service-oriented philosophy, a few concepts and workflows arise which may not feel completely natural in the LLM landscape, especially if you are coming with a background in other frameworks.
@ -26,7 +34,7 @@ We are working on adding a few more APIs to complete the application lifecycle.
The goal of Llama Stack is to build an ecosystem where users can easily swap out different implementations for the same API. Examples for these include: The goal of Llama Stack is to build an ecosystem where users can easily swap out different implementations for the same API. Examples for these include:
- LLM inference providers (e.g., Fireworks, Together, AWS Bedrock, Groq, Cerebras, SambaNova, vLLM, etc.), - LLM inference providers (e.g., Fireworks, Together, AWS Bedrock, Groq, Cerebras, SambaNova, vLLM, etc.),
- Vector databases (e.g., ChromaDB, Weaviate, Qdrant, FAISS, PGVector, etc.), - Vector databases (e.g., ChromaDB, Weaviate, Qdrant, Milvus, FAISS, PGVector, etc.),
- Safety providers (e.g., Meta's Llama Guard, AWS Bedrock Guardrails, etc.) - Safety providers (e.g., Meta's Llama Guard, AWS Bedrock Guardrails, etc.)
Providers come in two flavors: Providers come in two flavors:

View file

@ -13,13 +13,19 @@
# https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information # https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information
from docutils import nodes from docutils import nodes
import tomli # Import tomli for TOML parsing
from pathlib import Path from pathlib import Path
import requests
import json
# Read version from pyproject.toml # Read version from pyproject.toml
with Path(__file__).parent.parent.parent.joinpath("pyproject.toml").open("rb") as f: with Path(__file__).parent.parent.parent.joinpath("pyproject.toml").open("rb") as f:
pyproject = tomli.load(f) pypi_url = "https://pypi.org/pypi/llama-stack/json"
llama_stack_version = pyproject["project"]["version"] version_tag = json.loads(requests.get(pypi_url).text)["info"]["version"]
print(f"{version_tag=}")
# generate the full link including text and url here
llama_stack_version_url = f"https://github.com/meta-llama/llama-stack/releases/tag/v{version_tag}"
llama_stack_version_link = f"<a href='{llama_stack_version_url}'>release notes</a>"
project = "llama-stack" project = "llama-stack"
copyright = "2025, Meta" copyright = "2025, Meta"
@ -73,7 +79,8 @@ myst_enable_extensions = [
myst_substitutions = { myst_substitutions = {
"docker_hub": "https://hub.docker.com/repository/docker/llamastack", "docker_hub": "https://hub.docker.com/repository/docker/llamastack",
"llama_stack_version": llama_stack_version, "llama_stack_version": version_tag,
"llama_stack_version_link": llama_stack_version_link,
} }
suppress_warnings = ['myst.header'] suppress_warnings = ['myst.header']

View file

@ -17,25 +17,31 @@ Here are some example PRs to help you get started:
## Testing the Provider ## Testing the Provider
Before running tests, you must have the required dependencies installed. This depends on the providers or distributions you are testing. For example, if you are testing the `together` distribution, you should install its dependencies via `llama stack build --template together`.
### 1. Integration Testing ### 1. Integration Testing
- Create integration tests that use real provider instances and configurations
- For remote services, test actual API interactions
- Avoid mocking at the provider level since adapter layers tend to be thin
- Reference examples in {repopath}`tests/client-sdk`
### 2. Unit Testing (Optional) Integration tests are located in {repopath}`tests/integration`. These tests use the Python client SDK APIs (from the `llama_stack_client` package) to test functionality. Since these tests use client APIs, they can be run either by pointing to an instance of the Llama Stack server or "inline" by using `LlamaStackAsLibraryClient`.
- Add unit tests for provider-specific functionality
- See examples in {repopath}`llama_stack/providers/tests/inference/test_text_inference.py` Consult {repopath}`tests/integration/README.md` for more details on how to run the tests.
Note that each provider's `sample_run_config()` method (in the configuration class for that provider)
typically references some environment variables for specifying API keys and the like. You can set these in the environment or pass these via the `--env` flag to the test command.
### 2. Unit Testing
Unit tests are located in {repopath}`tests/unit`. Provider-specific unit tests are located in {repopath}`tests/unit/providers`. These tests are all run automatically as part of the CI process.
### 3. Additional end-to-end testing
### 3. End-to-End Testing
1. Start a Llama Stack server with your new provider 1. Start a Llama Stack server with your new provider
2. Test using client requests 2. Verify compatibility with existing client scripts in the [llama-stack-apps](https://github.com/meta-llama/llama-stack-apps/tree/main) repository
3. Verify compatibility with existing client scripts in the [llama-stack-apps](https://github.com/meta-llama/llama-stack-apps/tree/main) repository 3. Document which scripts are compatible with your provider
4. Document which scripts are compatible with your provider
## Submitting Your PR ## Submitting Your PR
1. Ensure all tests pass 1. Ensure all tests pass
2. Include a comprehensive test plan in your PR summary 2. Include a comprehensive test plan in your PR summary
3. Document any known limitations or considerations 3. Document any known limitations or considerations
4. Submit your pull request for review

View file

@ -4,6 +4,35 @@
This guide will walk you through the steps to get started with building a Llama Stack distribution from scratch with your choice of API providers. This guide will walk you through the steps to get started with building a Llama Stack distribution from scratch with your choice of API providers.
### Setting your log level
To specify the desired logging level, users can set the `LLAMA_STACK_LOGGING` environment variable using the following format:
`LLAMA_STACK_LOGGING=server=debug;core=info`
Each category in the following list:
- all
- core
- server
- router
- inference
- agents
- safety
- eval
- tools
- client
can be set to any of the following log levels:
- debug
- info
- warning
- error
- critical
The default global log level is `info`. `all` sets the log level for all components.
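For example, to run a distribution with debug logging for the server component and info logging for core, quote the value so the shell does not split on the `;` (the config path below is illustrative):
```bash
# Enable debug logs for the server and info logs for core, then start the stack.
LLAMA_STACK_LOGGING="server=debug;core=info" llama stack run ./run.yaml
```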
### Llama Stack Build ### Llama Stack Build
In order to build your own distribution, we recommend you clone the `llama-stack` repository. In order to build your own distribution, we recommend you clone the `llama-stack` repository.
@ -30,7 +59,7 @@ Build a Llama stack container
options: options:
-h, --help show this help message and exit -h, --help show this help message and exit
--config CONFIG Path to a config file to use for the build. You can find example configs in llama_stack/distribution/**/build.yaml. --config CONFIG Path to a config file to use for the build. You can find example configs in llama_stack/distributions/**/build.yaml.
If this argument is not provided, you will be prompted to enter information interactively If this argument is not provided, you will be prompted to enter information interactively
--template TEMPLATE Name of the example template config to use for build. You may use `llama stack build --list-templates` to check out the available templates --template TEMPLATE Name of the example template config to use for build. You may use `llama stack build --list-templates` to check out the available templates
--list-templates Show the available templates for building a Llama Stack distribution --list-templates Show the available templates for building a Llama Stack distribution
@ -106,7 +135,7 @@ It would be best to start with a template and understand the structure of the co
llama stack build llama stack build
> Enter a name for your Llama Stack (e.g. my-local-stack): my-stack > Enter a name for your Llama Stack (e.g. my-local-stack): my-stack
> Enter the image type you want your Llama Stack to be built as (container or conda): conda > Enter the image type you want your Llama Stack to be built as (container or conda or venv): conda
Llama Stack is composed of several APIs working together. Let's select Llama Stack is composed of several APIs working together. Let's select
the provider types (implementations) you want to use for these APIs. the provider types (implementations) you want to use for these APIs.
@ -187,14 +216,14 @@ usage: llama stack run [-h] [--port PORT] [--image-name IMAGE_NAME] [--disable-i
[--tls-certfile TLS_CERTFILE] [--image-type {conda,container,venv}] [--tls-certfile TLS_CERTFILE] [--image-type {conda,container,venv}]
config config
start the server for a Llama Stack Distribution. You should have already built (or downloaded) and configured the distribution. Start the server for a Llama Stack Distribution. You should have already built (or downloaded) and configured the distribution.
positional arguments: positional arguments:
config Path to config file to use for the run config Path to config file to use for the run
options: options:
-h, --help show this help message and exit -h, --help show this help message and exit
--port PORT Port to run the server on. Defaults to 8321 --port PORT Port to run the server on. It can also be passed via the env var LLAMA_STACK_PORT. Defaults to 8321
--image-name IMAGE_NAME --image-name IMAGE_NAME
Name of the image to run. Defaults to the current conda environment Name of the image to run. Defaults to the current conda environment
--disable-ipv6 Disable IPv6 support --disable-ipv6 Disable IPv6 support

View file

@ -23,7 +23,7 @@ The `llamastack/distribution-fireworks` distribution consists of the following p
| safety | `inline::llama-guard` | | safety | `inline::llama-guard` |
| scoring | `inline::basic`, `inline::llm-as-judge`, `inline::braintrust` | | scoring | `inline::basic`, `inline::llm-as-judge`, `inline::braintrust` |
| telemetry | `inline::meta-reference` | | telemetry | `inline::meta-reference` |
| tool_runtime | `remote::brave-search`, `remote::tavily-search`, `inline::code-interpreter`, `inline::rag-runtime`, `remote::model-context-protocol` | | tool_runtime | `remote::brave-search`, `remote::tavily-search`, `remote::wolfram-alpha`, `inline::code-interpreter`, `inline::rag-runtime`, `remote::model-context-protocol` |
| vector_io | `inline::faiss`, `remote::chromadb`, `remote::pgvector` | | vector_io | `inline::faiss`, `remote::chromadb`, `remote::pgvector` |

View file

@ -42,12 +42,31 @@ The following environment variables can be configured:
## Prerequisite: Downloading Models ## Prerequisite: Downloading Models
Please make sure you have llama model checkpoints downloaded in `~/.llama` before proceeding. See [installation guide](https://llama-stack.readthedocs.io/en/latest/references/llama_cli_reference/download_models.html) here to download the models. Run `llama model list` to see the available models to download, and `llama model download` to download the checkpoints. Please use `llama model list --downloaded` to check that you have llama model checkpoints downloaded in `~/.llama` before proceeding. See [installation guide](https://llama-stack.readthedocs.io/en/latest/references/llama_cli_reference/download_models.html) here to download the models. Run `llama model list` to see the available models to download, and `llama model download` to download the checkpoints.
``` ```
$ ls ~/.llama/checkpoints $ llama model list --downloaded
Llama3.1-8B Llama3.2-11B-Vision-Instruct Llama3.2-1B-Instruct Llama3.2-90B-Vision-Instruct Llama-Guard-3-8B ┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━┓
Llama3.1-8B-Instruct Llama3.2-1B Llama3.2-3B-Instruct Llama-Guard-3-1B Prompt-Guard-86M ┃ Model ┃ Size ┃ Modified Time ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━┩
│ Llama3.2-1B-Instruct:int4-qlora-eo8 │ 1.53 GB │ 2025-02-26 11:22:28 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Llama3.2-1B │ 2.31 GB │ 2025-02-18 21:48:52 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Prompt-Guard-86M │ 0.02 GB │ 2025-02-26 11:29:28 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Llama3.2-3B-Instruct:int4-spinquant-eo8 │ 3.69 GB │ 2025-02-26 11:37:41 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Llama3.2-3B │ 5.99 GB │ 2025-02-18 21:51:26 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Llama3.1-8B │ 14.97 GB │ 2025-02-16 10:36:37 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Llama3.2-1B-Instruct:int4-spinquant-eo8 │ 1.51 GB │ 2025-02-26 11:35:02 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Llama-Guard-3-1B │ 2.80 GB │ 2025-02-26 11:20:46 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Llama-Guard-3-1B:int4 │ 0.43 GB │ 2025-02-26 11:33:33 │
└─────────────────────────────────────────┴──────────┴─────────────────────┘
``` ```
## Running the Distribution ## Running the Distribution

View file

@ -42,12 +42,31 @@ The following environment variables can be configured:
## Prerequisite: Downloading Models ## Prerequisite: Downloading Models
Please make sure you have llama model checkpoints downloaded in `~/.llama` before proceeding. See [installation guide](https://llama-stack.readthedocs.io/en/latest/references/llama_cli_reference/download_models.html) here to download the models. Run `llama model list` to see the available models to download, and `llama model download` to download the checkpoints. Please use `llama model list --downloaded` to check that you have llama model checkpoints downloaded in `~/.llama` before proceeding. See [installation guide](https://llama-stack.readthedocs.io/en/latest/references/llama_cli_reference/download_models.html) here to download the models. Run `llama model list` to see the available models to download, and `llama model download` to download the checkpoints.
``` ```
$ ls ~/.llama/checkpoints $ llama model list --downloaded
Llama3.1-8B Llama3.2-11B-Vision-Instruct Llama3.2-1B-Instruct Llama3.2-90B-Vision-Instruct Llama-Guard-3-8B ┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━┓
Llama3.1-8B-Instruct Llama3.2-1B Llama3.2-3B-Instruct Llama-Guard-3-1B Prompt-Guard-86M ┃ Model ┃ Size ┃ Modified Time ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━┩
│ Llama3.2-1B-Instruct:int4-qlora-eo8 │ 1.53 GB │ 2025-02-26 11:22:28 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Llama3.2-1B │ 2.31 GB │ 2025-02-18 21:48:52 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Prompt-Guard-86M │ 0.02 GB │ 2025-02-26 11:29:28 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Llama3.2-3B-Instruct:int4-spinquant-eo8 │ 3.69 GB │ 2025-02-26 11:37:41 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Llama3.2-3B │ 5.99 GB │ 2025-02-18 21:51:26 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Llama3.1-8B │ 14.97 GB │ 2025-02-16 10:36:37 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Llama3.2-1B-Instruct:int4-spinquant-eo8 │ 1.51 GB │ 2025-02-26 11:35:02 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Llama-Guard-3-1B │ 2.80 GB │ 2025-02-26 11:20:46 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Llama-Guard-3-1B:int4 │ 0.43 GB │ 2025-02-26 11:33:33 │
└─────────────────────────────────────────┴──────────┴─────────────────────┘
``` ```
## Running the Distribution ## Running the Distribution

View file

@ -23,7 +23,7 @@ The `llamastack/distribution-ollama` distribution consists of the following prov
| safety | `inline::llama-guard` | | safety | `inline::llama-guard` |
| scoring | `inline::basic`, `inline::llm-as-judge`, `inline::braintrust` | | scoring | `inline::basic`, `inline::llm-as-judge`, `inline::braintrust` |
| telemetry | `inline::meta-reference` | | telemetry | `inline::meta-reference` |
| tool_runtime | `remote::brave-search`, `remote::tavily-search`, `inline::code-interpreter`, `inline::rag-runtime`, `remote::model-context-protocol` | | tool_runtime | `remote::brave-search`, `remote::tavily-search`, `inline::code-interpreter`, `inline::rag-runtime`, `remote::model-context-protocol`, `remote::wolfram-alpha` |
| vector_io | `inline::sqlite-vec`, `remote::chromadb`, `remote::pgvector` | | vector_io | `inline::sqlite-vec`, `remote::chromadb`, `remote::pgvector` |

View file

@ -22,7 +22,7 @@ The `llamastack/distribution-remote-vllm` distribution consists of the following
| safety | `inline::llama-guard` | | safety | `inline::llama-guard` |
| scoring | `inline::basic`, `inline::llm-as-judge`, `inline::braintrust` | | scoring | `inline::basic`, `inline::llm-as-judge`, `inline::braintrust` |
| telemetry | `inline::meta-reference` | | telemetry | `inline::meta-reference` |
| tool_runtime | `remote::brave-search`, `remote::tavily-search`, `inline::code-interpreter`, `inline::rag-runtime`, `remote::model-context-protocol` | | tool_runtime | `remote::brave-search`, `remote::tavily-search`, `inline::code-interpreter`, `inline::rag-runtime`, `remote::model-context-protocol`, `remote::wolfram-alpha` |
| vector_io | `inline::faiss`, `remote::chromadb`, `remote::pgvector` | | vector_io | `inline::faiss`, `remote::chromadb`, `remote::pgvector` |

View file

@ -36,7 +36,7 @@ The following environment variables can be configured:
- `LLAMA_STACK_PORT`: Port for the Llama Stack distribution server (default: `5001`) - `LLAMA_STACK_PORT`: Port for the Llama Stack distribution server (default: `5001`)
- `INFERENCE_MODEL`: Inference model loaded into the TGI server (default: `meta-llama/Llama-3.2-3B-Instruct`) - `INFERENCE_MODEL`: Inference model loaded into the TGI server (default: `meta-llama/Llama-3.2-3B-Instruct`)
- `TGI_URL`: URL of the TGI server with the main inference model (default: `http://127.0.0.1:8080}/v1`) - `TGI_URL`: URL of the TGI server with the main inference model (default: `http://127.0.0.1:8080/v1`)
- `TGI_SAFETY_URL`: URL of the TGI server with the safety model (default: `http://127.0.0.1:8081/v1`) - `TGI_SAFETY_URL`: URL of the TGI server with the safety model (default: `http://127.0.0.1:8081/v1`)
- `SAFETY_MODEL`: Name of the safety (Llama-Guard) model to use (default: `meta-llama/Llama-Guard-3-1B`) - `SAFETY_MODEL`: Name of the safety (Llama-Guard) model to use (default: `meta-llama/Llama-Guard-3-1B`)

View file

@ -23,7 +23,7 @@ The `llamastack/distribution-together` distribution consists of the following pr
| safety | `inline::llama-guard` | | safety | `inline::llama-guard` |
| scoring | `inline::basic`, `inline::llm-as-judge`, `inline::braintrust` | | scoring | `inline::basic`, `inline::llm-as-judge`, `inline::braintrust` |
| telemetry | `inline::meta-reference` | | telemetry | `inline::meta-reference` |
| tool_runtime | `remote::brave-search`, `remote::tavily-search`, `inline::code-interpreter`, `inline::rag-runtime`, `remote::model-context-protocol` | | tool_runtime | `remote::brave-search`, `remote::tavily-search`, `inline::code-interpreter`, `inline::rag-runtime`, `remote::model-context-protocol`, `remote::wolfram-alpha` |
| vector_io | `inline::faiss`, `remote::chromadb`, `remote::pgvector` | | vector_io | `inline::faiss`, `remote::chromadb`, `remote::pgvector` |

View file

@ -38,7 +38,7 @@ The API is **exactly identical** for both clients.
:::{dropdown} Starting up the Llama Stack server :::{dropdown} Starting up the Llama Stack server
The Llama Stack server can be configured flexibly so you can mix-and-match various providers for its individual API components -- beyond Inference, these include Vector IO, Agents, Telemetry, Evals, Post Training, etc. The Llama Stack server can be configured flexibly so you can mix-and-match various providers for its individual API components -- beyond Inference, these include Vector IO, Agents, Telemetry, Evals, Post Training, etc.
To get started quickly, we provide various container images for the server component that work with different inference providers out of the box. For this guide, we will use `llamastack/distribution-ollama` as the container image. To get started quickly, we provide various container images for the server component that work with different inference providers out of the box. For this guide, we will use `llamastack/distribution-ollama` as the container image. If you'd like to build your own image or customize the configurations, please check out [this guide](../references/index.md).
Let's set up some environment variables that we will use in the rest of the guide. Let's set up some environment variables that we will use in the rest of the guide.
```bash ```bash
@ -184,7 +184,6 @@ from termcolor import cprint
from llama_stack_client.lib.agents.agent import Agent from llama_stack_client.lib.agents.agent import Agent
from llama_stack_client.lib.agents.event_logger import EventLogger from llama_stack_client.lib.agents.event_logger import EventLogger
from llama_stack_client.types.agent_create_params import AgentConfig
from llama_stack_client.types import Document from llama_stack_client.types import Document
@ -241,13 +240,14 @@ client.tool_runtime.rag_tool.insert(
chunk_size_in_tokens=512, chunk_size_in_tokens=512,
) )
agent_config = AgentConfig( rag_agent = Agent(
client,
model=os.environ["INFERENCE_MODEL"], model=os.environ["INFERENCE_MODEL"],
# Define instructions for the agent (aka system prompt) # Define instructions for the agent (aka system prompt)
instructions="You are a helpful assistant", instructions="You are a helpful assistant",
enable_session_persistence=False, enable_session_persistence=False,
# Define tools available to the agent # Define tools available to the agent
toolgroups=[ tools=[
{ {
"name": "builtin::rag/knowledge_search", "name": "builtin::rag/knowledge_search",
"args": { "args": {
@ -256,12 +256,10 @@ agent_config = AgentConfig(
} }
], ],
) )
rag_agent = Agent(client, agent_config)
session_id = rag_agent.create_session("test-session") session_id = rag_agent.create_session("test-session")
user_prompts = [ user_prompts = [
"What are the top 5 topics that were explained? Only list succinct bullet points.", "How to optimize memory usage in torchtune? use the knowledge_search tool to get information.",
] ]
# Run the agent loop by calling the `create_turn` method # Run the agent loop by calling the `create_turn` method
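The loop itself is not shown in this hunk; a possible shape, following the client conventions used elsewhere in this document (treat the exact event-logging helper as an assumption):

```python
from llama_stack_client.lib.agents.event_logger import EventLogger

for prompt in user_prompts:
    # Each turn is created inside the session established above.
    turn_response = rag_agent.create_turn(
        messages=[{"role": "user", "content": prompt}],
        session_id=session_id,
    )
    # Stream and pretty-print the agent's intermediate steps and final answer.
    for log in EventLogger().log(turn_response):
        log.print()
```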

View file

@ -1,7 +1,7 @@
```{admonition} News ```{admonition} News
:class: tip :class: tip
Llama Stack {{ llama_stack_version }} is now available! See the [release notes](https://github.com/meta-llama/llama-stack/releases/tag/v{{ llama_stack_version }}) for more details. Llama Stack {{ llama_stack_version }} is now available! See the {{ llama_stack_version_link }} for more details.
``` ```
# Llama Stack # Llama Stack
@ -68,6 +68,7 @@ A number of "adapters" are available for some popular Inference and Vector Store
| FAISS | Single Node | | FAISS | Single Node |
| SQLite-Vec| Single Node | | SQLite-Vec| Single Node |
| Chroma | Hosted and Single Node | | Chroma | Hosted and Single Node |
| Milvus | Hosted and Single Node |
| Postgres (PGVector) | Hosted and Single Node | | Postgres (PGVector) | Hosted and Single Node |
| Weaviate | Hosted | | Weaviate | Hosted |

View file

@ -2,7 +2,7 @@
The goal of Llama Stack is to build an ecosystem where users can easily swap out different implementations for the same API. Examples for these include: The goal of Llama Stack is to build an ecosystem where users can easily swap out different implementations for the same API. Examples for these include:
- LLM inference providers (e.g., Fireworks, Together, AWS Bedrock, Groq, Cerebras, SambaNova, vLLM, etc.), - LLM inference providers (e.g., Fireworks, Together, AWS Bedrock, Groq, Cerebras, SambaNova, vLLM, etc.),
- Vector databases (e.g., ChromaDB, Weaviate, Qdrant, FAISS, PGVector, etc.), - Vector databases (e.g., ChromaDB, Weaviate, Qdrant, Milvus, FAISS, PGVector, etc.),
- Safety providers (e.g., Meta's Llama Guard, AWS Bedrock Guardrails, etc.) - Safety providers (e.g., Meta's Llama Guard, AWS Bedrock Guardrails, etc.)
Providers come in two flavors: Providers come in two flavors:
@ -36,7 +36,7 @@ Evaluates the outputs of the system.
Collects telemetry data from the system. Collects telemetry data from the system.
## Tool Runtime ## Tool Runtime
Is associated with the ToolGroup resources. Is associated with the ToolGroup resources.
## Vector IO ## Vector IO
@ -55,5 +55,6 @@ vector_io/sqlite-vec
vector_io/chromadb vector_io/chromadb
vector_io/pgvector vector_io/pgvector
vector_io/qdrant vector_io/qdrant
vector_io/milvus
vector_io/weaviate vector_io/weaviate
``` ```

View file

@ -1,10 +1,10 @@
--- ---
orphan: true orphan: true
--- ---
# Chroma # Chroma
[Chroma](https://www.trychroma.com/) is an inline and remote vector [Chroma](https://www.trychroma.com/) is an inline and remote vector
database provider for Llama Stack. It allows you to store and query vectors directly within a Chroma database. database provider for Llama Stack. It allows you to store and query vectors directly within a Chroma database.
That means you're not limited to storing vectors in memory or in a separate service. That means you're not limited to storing vectors in memory or in a separate service.
## Features ## Features

View file

@ -3,7 +3,7 @@ orphan: true
--- ---
# Faiss # Faiss
[Faiss](https://github.com/facebookresearch/faiss) is an inline vector database provider for Llama Stack. It [Faiss](https://github.com/facebookresearch/faiss) is an inline vector database provider for Llama Stack. It
allows you to store and query vectors directly in memory. allows you to store and query vectors directly in memory.
That means you'll get fast and efficient vector retrieval. That means you'll get fast and efficient vector retrieval.
@ -29,5 +29,5 @@ You can install Faiss using pip:
pip install faiss-cpu pip install faiss-cpu
``` ```
## Documentation ## Documentation
See [Faiss' documentation](https://faiss.ai/) or the [Faiss Wiki](https://github.com/facebookresearch/faiss/wiki) for See [Faiss' documentation](https://faiss.ai/) or the [Faiss Wiki](https://github.com/facebookresearch/faiss/wiki) for
more details about Faiss in general. more details about Faiss in general.
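For readers curious what in-memory retrieval looks like at the Faiss level itself, independent of Llama Stack, here is a minimal standalone sketch (random vectors, exact L2 search):

```python
import faiss
import numpy as np

dim = 128
database_vectors = np.random.random((1000, dim)).astype("float32")
query_vectors = np.random.random((5, dim)).astype("float32")

index = faiss.IndexFlatL2(dim)      # exact nearest-neighbor search, fully in memory
index.add(database_vectors)
distances, ids = index.search(query_vectors, 4)   # 4 nearest neighbors per query
print(ids)
```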

View file

@ -0,0 +1,31 @@
---
orphan: true
---
# Milvus
[Milvus](https://milvus.io/) is an inline and remote vector database provider for Llama Stack. It
allows you to store and query vectors directly within a Milvus database.
That means you're not limited to storing vectors in memory or in a separate service.
## Features
- Easy to use
- Fully integrated with Llama Stack
## Usage
To use Milvus in your Llama Stack project, follow these steps:
1. Install the necessary dependencies.
2. Configure your Llama Stack project to use Milvus.
3. Start storing and querying vectors.
## Installation
You can install Milvus using pymilvus:
```bash
pip install pymilvus
```
## Documentation
See the [Milvus documentation](https://milvus.io/docs/install-overview.md) for more details about Milvus in general.
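As a rough sketch of how a Milvus-backed vector DB is typically registered from the client side (the method and argument names follow the llama-stack-client conventions used elsewhere in this document and should be treated as assumptions; `provider_id` must match the provider configured in your run config):

```python
from llama_stack_client import LlamaStackClient

client = LlamaStackClient(base_url="http://localhost:5001")

# Hypothetical registration call for a Milvus-backed vector store.
client.vector_dbs.register(
    vector_db_id="my-milvus-db",
    provider_id="milvus",
    embedding_model="all-MiniLM-L6-v2",
    embedding_dimension=384,
)
```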

View file

@ -3,7 +3,7 @@ orphan: true
--- ---
# Postgres PGVector # Postgres PGVector
[PGVector](https://github.com/pgvector/pgvector) is a remote vector database provider for Llama Stack. It [PGVector](https://github.com/pgvector/pgvector) is a remote vector database provider for Llama Stack. It
allows you to store and query vectors directly in memory. allows you to store and query vectors directly in memory.
That means you'll get fast and efficient vector retrieval. That means you'll get fast and efficient vector retrieval.

View file

@ -3,7 +3,7 @@ orphan: true
--- ---
# Qdrant # Qdrant
[Qdrant](https://qdrant.tech/documentation/) is a remote vector database provider for Llama Stack. It [Qdrant](https://qdrant.tech/documentation/) is a remote vector database provider for Llama Stack. It
allows you to store and query vectors directly in memory. allows you to store and query vectors directly in memory.
That means you'll get fast and efficient vector retrieval. That means you'll get fast and efficient vector retrieval.

View file

@ -3,8 +3,8 @@ orphan: true
--- ---
# SQLite-Vec # SQLite-Vec
[SQLite-Vec](https://github.com/asg017/sqlite-vec) is an inline vector database provider for Llama Stack. It [SQLite-Vec](https://github.com/asg017/sqlite-vec) is an inline vector database provider for Llama Stack. It
allows you to store and query vectors directly within an SQLite database. allows you to store and query vectors directly within an SQLite database.
That means you're not limited to storing vectors in memory or in a separate service. That means you're not limited to storing vectors in memory or in a separate service.
## Features ## Features

View file

@ -1,10 +1,10 @@
--- ---
orphan: true orphan: true
--- ---
# Weaviate # Weaviate
[Weaviate](https://weaviate.io/) is a vector database provider for Llama Stack. [Weaviate](https://weaviate.io/) is a vector database provider for Llama Stack.
It allows you to store and query vectors directly within a Weaviate database. It allows you to store and query vectors directly within a Weaviate database.
That means you're not limited to storing vectors in memory or in a separate service. That means you're not limited to storing vectors in memory or in a separate service.
## Features ## Features
@ -27,7 +27,7 @@ To use Weaviate in your Llama Stack project, follow these steps:
## Installation ## Installation
To install Weaviate see the [Weaviate quickstart documentation](https://weaviate.io/developers/weaviate/quickstart). To install Weaviate see the [Weaviate quickstart documentation](https://weaviate.io/developers/weaviate/quickstart).
## Documentation ## Documentation
See [Weaviate's documentation](https://weaviate.io/developers/weaviate) for more details about Weaviate in general. See [Weaviate's documentation](https://weaviate.io/developers/weaviate) for more details about Weaviate in general.

View file

@ -24,19 +24,9 @@ The Evaluation APIs are associated with a set of Resources as shown in the follo
- Associated with `Benchmark` resource. - Associated with `Benchmark` resource.
Use the following decision tree to decide how to use LlamaStack Evaluation flow.
![Eval Flow](./resources/eval-flow.png)
```{admonition} Note on Benchmark v.s. Application Evaluation
:class: tip
- **Benchmark Evaluation** is a well-defined eval-task consisting of `dataset` and `scoring_function`. The generation (inference or agent) will be done as part of evaluation.
- **Application Evaluation** assumes users already have app inputs & generated outputs. Evaluation will purely focus on scoring the generated outputs via scoring functions (e.g. LLM-as-judge).
```
## Evaluation Examples Walkthrough ## Evaluation Examples Walkthrough
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/10CHyykee9j2OigaIcRv47BKG9mrNm0tJ?usp=sharing) [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/meta-llama/llama-stack/blob/main/docs/notebooks/Llama_Stack_Benchmark_Evals.ipynb)
It is best to open this notebook in Colab to follow along with the examples. It is best to open this notebook in Colab to follow along with the examples.
@ -63,20 +53,29 @@ eval_rows = ds.to_pandas().to_dict(orient="records")
- Run evaluate on the dataset - Run evaluate on the dataset
```python ```python
from rich.pretty import pprint
from tqdm import tqdm
SYSTEM_PROMPT_TEMPLATE = """ SYSTEM_PROMPT_TEMPLATE = """
You are an expert in Agriculture whose job is to answer questions from the user using images. You are an expert in {subject} whose job is to answer questions from the user using images.
First, reason about the correct answer. First, reason about the correct answer.
Then write the answer in the following format where X is exactly one of A,B,C,D: Then write the answer in the following format where X is exactly one of A,B,C,D:
Answer: X Answer: X
Make sure X is one of A,B,C,D. Make sure X is one of A,B,C,D.
If you are uncertain of the correct answer, guess the most likely one. If you are uncertain of the correct answer, guess the most likely one.
""" """
system_message = { system_message = {
"role": "system", "role": "system",
"content": SYSTEM_PROMPT_TEMPLATE, "content": SYSTEM_PROMPT_TEMPLATE.format(subject=subset),
} }
# register the evaluation benchmark task with the dataset and scoring function
client.benchmarks.register( client.benchmarks.register(
benchmark_id="meta-reference::mmmu", benchmark_id="meta-reference::mmmu",
dataset_id=f"mmmu-{subset}-{split}", dataset_id=f"mmmu-{subset}-{split}",
@ -87,14 +86,15 @@ response = client.eval.evaluate_rows(
benchmark_id="meta-reference::mmmu", benchmark_id="meta-reference::mmmu",
input_rows=eval_rows, input_rows=eval_rows,
scoring_functions=["basic::regex_parser_multiple_choice_answer"], scoring_functions=["basic::regex_parser_multiple_choice_answer"],
task_config={ benchmark_config={
"type": "benchmark",
"eval_candidate": { "eval_candidate": {
"type": "model", "type": "model",
"model": "meta-llama/Llama-3.2-90B-Vision-Instruct", "model": "meta-llama/Llama-3.2-90B-Vision-Instruct",
"sampling_params": { "sampling_params": {
"strategy": { "strategy": {
"type": "greedy", "type": "top_p",
"temperature": 1.0,
"top_p": 0.95,
}, },
"max_tokens": 4096, "max_tokens": 4096,
"repeat_penalty": 1.0, "repeat_penalty": 1.0,
@ -103,6 +103,7 @@ response = client.eval.evaluate_rows(
}, },
}, },
) )
pprint(response)
``` ```
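`evaluate_rows` returns an `EvaluateResponse` (its schema appears later in this diff), so the scores can be inspected per scoring function; continuing from the `response` object above:

```python
# One ScoringResult per requested scoring function.
result = response.scores["basic::regex_parser_multiple_choice_answer"]
print(result.aggregated_results)   # aggregated metrics, e.g. accuracy
print(result.score_rows[:3])       # per-row scores for the first few rows
```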
#### 1.2. Running SimpleQA #### 1.2. Running SimpleQA
@ -115,10 +116,9 @@ simpleqa_dataset_id = "huggingface::simpleqa"
_ = client.datasets.register( _ = client.datasets.register(
dataset_id=simpleqa_dataset_id, dataset_id=simpleqa_dataset_id,
provider_id="huggingface", provider_id="huggingface",
url={"uri": "https://huggingface.co/datasets/llamastack/evals"}, url={"uri": "https://huggingface.co/datasets/llamastack/simpleqa"},
metadata={ metadata={
"path": "llamastack/evals", "path": "llamastack/simpleqa",
"name": "evals__simpleqa",
"split": "train", "split": "train",
}, },
dataset_schema={ dataset_schema={
@ -145,8 +145,7 @@ response = client.eval.evaluate_rows(
benchmark_id="meta-reference::simpleqa", benchmark_id="meta-reference::simpleqa",
input_rows=eval_rows.rows, input_rows=eval_rows.rows,
scoring_functions=["llm-as-judge::405b-simpleqa"], scoring_functions=["llm-as-judge::405b-simpleqa"],
task_config={ benchmark_config={
"type": "benchmark",
"eval_candidate": { "eval_candidate": {
"type": "model", "type": "model",
"model": "meta-llama/Llama-3.2-90B-Vision-Instruct", "model": "meta-llama/Llama-3.2-90B-Vision-Instruct",
@ -160,6 +159,7 @@ response = client.eval.evaluate_rows(
}, },
}, },
) )
pprint(response)
``` ```
@ -170,19 +170,17 @@ response = client.eval.evaluate_rows(
```python ```python
agent_config = { agent_config = {
"model": "meta-llama/Llama-3.1-405B-Instruct", "model": "meta-llama/Llama-3.3-70B-Instruct",
"instructions": "You are a helpful assistant", "instructions": "You are a helpful assistant that have access to tool to search the web. ",
"sampling_params": { "sampling_params": {
"strategy": { "strategy": {
"type": "greedy", "type": "top_p",
}, "temperature": 0.5,
}, "top_p": 0.9,
"tools": [
{
"type": "brave_search",
"engine": "tavily",
"api_key": userdata.get("TAVILY_SEARCH_API_KEY"),
} }
},
"toolgroups": [
"builtin::websearch",
], ],
"tool_choice": "auto", "tool_choice": "auto",
"tool_prompt_format": "json", "tool_prompt_format": "json",
@ -195,25 +193,22 @@ response = client.eval.evaluate_rows(
benchmark_id="meta-reference::simpleqa", benchmark_id="meta-reference::simpleqa",
input_rows=eval_rows.rows, input_rows=eval_rows.rows,
scoring_functions=["llm-as-judge::405b-simpleqa"], scoring_functions=["llm-as-judge::405b-simpleqa"],
task_config={ benchmark_config={
"type": "benchmark",
"eval_candidate": { "eval_candidate": {
"type": "agent", "type": "agent",
"config": agent_config, "config": agent_config,
}, },
}, },
) )
pprint(response)
``` ```
### 3. Agentic Application Dataset Scoring ### 3. Agentic Application Dataset Scoring
- Llama Stack offers a library of scoring functions and the `/scoring` API, allowing you to run evaluations on your pre-annotated AI application datasets. [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/meta-llama/llama-stack/blob/main/docs/getting_started.ipynb)
- In this example, we will work with an example RAG dataset and couple of scoring functions for evaluation. Llama Stack offers a library of scoring functions and the `/scoring` API, allowing you to run evaluations on your pre-annotated AI application datasets.
- `llm-as-judge::base`: LLM-As-Judge with custom judge prompt & model.
- `braintrust::factuality`: Factuality scorer from [braintrust](https://github.com/braintrustdata/autoevals).
- `basic::subset_of`: Basic checking if generated answer is a subset of expected answer.
- Please checkout our [Llama Stack Playground](https://llama-stack.readthedocs.io/en/latest/playground/index.html) for an interactive interface to upload datasets and run scorings. In this example, we will work with an example RAG dataset you have built previously, label it with an annotation, and use LLM-As-Judge with a custom judge prompt for scoring. Please check out our [Llama Stack Playground](https://llama-stack.readthedocs.io/en/latest/playground/index.html) for an interactive interface to upload datasets and run scorings.
```python ```python
judge_model_id = "meta-llama/Llama-3.1-405B-Instruct-FP8" judge_model_id = "meta-llama/Llama-3.1-405B-Instruct-FP8"
@ -317,28 +312,9 @@ The `BenchmarkConfig` are user specified config to define:
2. Optionally scoring function params to allow customization of scoring function behaviour. This is useful to parameterize generic scoring functions such as LLMAsJudge with custom `judge_model` / `judge_prompt`. 2. Optionally scoring function params to allow customization of scoring function behaviour. This is useful to parameterize generic scoring functions such as LLMAsJudge with custom `judge_model` / `judge_prompt`.
**Example Benchmark BenchmarkConfig** **Example BenchmarkConfig**
```json ```json
{ {
"type": "benchmark",
"eval_candidate": {
"type": "model",
"model": "Llama3.2-3B-Instruct",
"sampling_params": {
"strategy": {
"type": "greedy",
},
"max_tokens": 0,
"repetition_penalty": 1.0
}
}
}
```
**Example Application BenchmarkConfig**
```json
{
"type": "app",
"eval_candidate": { "eval_candidate": {
"type": "model", "type": "model",
"model": "Llama3.1-405B-Instruct", "model": "Llama3.1-405B-Instruct",

View file

@ -129,3 +129,35 @@ llama download --source huggingface --model-id Prompt-Guard-86M --ignore-pattern
**Important:** Set your environment variable `HF_TOKEN` or pass in `--hf-token` to the command to validate your access. You can find your token at [https://huggingface.co/settings/tokens](https://huggingface.co/settings/tokens). **Important:** Set your environment variable `HF_TOKEN` or pass in `--hf-token` to the command to validate your access. You can find your token at [https://huggingface.co/settings/tokens](https://huggingface.co/settings/tokens).
> **Tip:** Default for `llama download` is to run with `--ignore-patterns *.safetensors` since we use the `.pth` files in the `original` folder. For Llama Guard and Prompt Guard, however, we need safetensors. Hence, please run with `--ignore-patterns original` so that safetensors are downloaded and `.pth` files are ignored. > **Tip:** Default for `llama download` is to run with `--ignore-patterns *.safetensors` since we use the `.pth` files in the `original` folder. For Llama Guard and Prompt Guard, however, we need safetensors. Hence, please run with `--ignore-patterns original` so that safetensors are downloaded and `.pth` files are ignored.
## List the downloaded models
To list the downloaded models, run the following command:
```
llama model list --downloaded
```
You should see a table like this:
```
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━┓
┃ Model ┃ Size ┃ Modified Time ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━┩
│ Llama3.2-1B-Instruct:int4-qlora-eo8 │ 1.53 GB │ 2025-02-26 11:22:28 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Llama3.2-1B │ 2.31 GB │ 2025-02-18 21:48:52 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Prompt-Guard-86M │ 0.02 GB │ 2025-02-26 11:29:28 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Llama3.2-3B-Instruct:int4-spinquant-eo8 │ 3.69 GB │ 2025-02-26 11:37:41 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Llama3.2-3B │ 5.99 GB │ 2025-02-18 21:51:26 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Llama3.1-8B │ 14.97 GB │ 2025-02-16 10:36:37 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Llama3.2-1B-Instruct:int4-spinquant-eo8 │ 1.51 GB │ 2025-02-26 11:35:02 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Llama-Guard-3-1B │ 2.80 GB │ 2025-02-26 11:20:46 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Llama-Guard-3-1B:int4 │ 0.43 GB │ 2025-02-26 11:33:33 │
└─────────────────────────────────────────┴──────────┴─────────────────────┘
```

View file

@ -154,6 +154,38 @@ llama download --source huggingface --model-id Prompt-Guard-86M --ignore-pattern
> **Tip:** Default for `llama download` is to run with `--ignore-patterns *.safetensors` since we use the `.pth` files in the `original` folder. For Llama Guard and Prompt Guard, however, we need safetensors. Hence, please run with `--ignore-patterns original` so that safetensors are downloaded and `.pth` files are ignored. > **Tip:** Default for `llama download` is to run with `--ignore-patterns *.safetensors` since we use the `.pth` files in the `original` folder. For Llama Guard and Prompt Guard, however, we need safetensors. Hence, please run with `--ignore-patterns original` so that safetensors are downloaded and `.pth` files are ignored.
## List the downloaded models
To list the downloaded models, run the following command:
```
llama model list --downloaded
```
You should see a table like this:
```
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━┓
┃ Model ┃ Size ┃ Modified Time ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━┩
│ Llama3.2-1B-Instruct:int4-qlora-eo8 │ 1.53 GB │ 2025-02-26 11:22:28 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Llama3.2-1B │ 2.31 GB │ 2025-02-18 21:48:52 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Prompt-Guard-86M │ 0.02 GB │ 2025-02-26 11:29:28 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Llama3.2-3B-Instruct:int4-spinquant-eo8 │ 3.69 GB │ 2025-02-26 11:37:41 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Llama3.2-3B │ 5.99 GB │ 2025-02-18 21:51:26 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Llama3.1-8B │ 14.97 GB │ 2025-02-16 10:36:37 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Llama3.2-1B-Instruct:int4-spinquant-eo8 │ 1.51 GB │ 2025-02-26 11:35:02 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Llama-Guard-3-1B │ 2.80 GB │ 2025-02-26 11:20:46 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Llama-Guard-3-1B:int4 │ 0.43 GB │ 2025-02-26 11:33:33 │
└─────────────────────────────────────────┴──────────┴─────────────────────┘
```
## Understand the models ## Understand the models
The `llama model` command helps you explore the models interface. The `llama model` command helps you explore the models interface.

View file

@ -294,8 +294,9 @@
" # Initialize custom tool (ensure `WebSearchTool` is defined earlier in the notebook)\n", " # Initialize custom tool (ensure `WebSearchTool` is defined earlier in the notebook)\n",
" webSearchTool = WebSearchTool(api_key=BRAVE_SEARCH_API_KEY)\n", " webSearchTool = WebSearchTool(api_key=BRAVE_SEARCH_API_KEY)\n",
"\n", "\n",
" # Define the agent configuration, including the model and tool setup\n", " # Create an agent instance with the client and configuration\n",
" agent_config = AgentConfig(\n", " agent = Agent(\n",
" client, \n",
" model=MODEL_NAME,\n", " model=MODEL_NAME,\n",
" instructions=\"\"\"You are a helpful assistant that responds to user queries with relevant information and cites sources when available.\"\"\",\n", " instructions=\"\"\"You are a helpful assistant that responds to user queries with relevant information and cites sources when available.\"\"\",\n",
" sampling_params={\n", " sampling_params={\n",
@ -303,17 +304,12 @@
" \"type\": \"greedy\",\n", " \"type\": \"greedy\",\n",
" },\n", " },\n",
" },\n", " },\n",
" tools=[webSearchTool.get_tool_definition()],\n", " tools=[webSearchTool],\n",
" tool_choice=\"auto\",\n",
" tool_prompt_format=\"python_list\",\n",
" input_shields=input_shields,\n", " input_shields=input_shields,\n",
" output_shields=output_shields,\n", " output_shields=output_shields,\n",
" enable_session_persistence=False,\n", " enable_session_persistence=False,\n",
" )\n", " )\n",
"\n", "\n",
" # Create an agent instance with the client and configuration\n",
" agent = Agent(client, agent_config, [webSearchTool])\n",
"\n",
" # Create a session for interaction and print the session ID\n", " # Create a session for interaction and print the session ID\n",
" session_id = agent.create_session(\"test-session\")\n", " session_id = agent.create_session(\"test-session\")\n",
" print(f\"Created session_id={session_id} for Agent({agent.agent_id})\")\n", " print(f\"Created session_id={session_id} for Agent({agent.agent_id})\")\n",

View file

@ -110,12 +110,12 @@
"from llama_stack_client import LlamaStackClient\n", "from llama_stack_client import LlamaStackClient\n",
"from llama_stack_client.lib.agents.agent import Agent\n", "from llama_stack_client.lib.agents.agent import Agent\n",
"from llama_stack_client.lib.agents.event_logger import EventLogger\n", "from llama_stack_client.lib.agents.event_logger import EventLogger\n",
"from llama_stack_client.types.agent_create_params import AgentConfig\n",
"\n", "\n",
"\n", "\n",
"async def agent_example():\n", "async def agent_example():\n",
" client = LlamaStackClient(base_url=f\"http://{HOST}:{PORT}\")\n", " client = LlamaStackClient(base_url=f\"http://{HOST}:{PORT}\")\n",
" agent_config = AgentConfig(\n", " agent = Agent(\n",
" client, \n",
" model=MODEL_NAME,\n", " model=MODEL_NAME,\n",
" instructions=\"You are a helpful assistant! If you call builtin tools like brave search, follow the syntax brave_search.call(…)\",\n", " instructions=\"You are a helpful assistant! If you call builtin tools like brave search, follow the syntax brave_search.call(…)\",\n",
" sampling_params={\n", " sampling_params={\n",
@ -130,14 +130,7 @@
" \"api_key\": BRAVE_SEARCH_API_KEY,\n", " \"api_key\": BRAVE_SEARCH_API_KEY,\n",
" }\n", " }\n",
" ],\n", " ],\n",
" tool_choice=\"auto\",\n",
" tool_prompt_format=\"function_tag\",\n",
" input_shields=[],\n",
" output_shields=[],\n",
" enable_session_persistence=False,\n",
" )\n", " )\n",
"\n",
" agent = Agent(client, agent_config)\n",
" session_id = agent.create_session(\"test-session\")\n", " session_id = agent.create_session(\"test-session\")\n",
" print(f\"Created session_id={session_id} for Agent({agent.agent_id})\")\n", " print(f\"Created session_id={session_id} for Agent({agent.agent_id})\")\n",
"\n", "\n",

View file

@ -73,7 +73,7 @@ If you're looking for more specific topics, we have a [Zero to Hero Guide](#next
Open a new terminal and install `llama-stack`: Open a new terminal and install `llama-stack`:
```bash ```bash
conda activate ollama conda activate ollama
pip install llama-stack==0.1.0 pip install -U llama-stack
``` ```
--- ---

View file

@ -103,7 +103,6 @@
"from llama_stack_client.lib.agents.agent import Agent\n", "from llama_stack_client.lib.agents.agent import Agent\n",
"from llama_stack_client.lib.agents.event_logger import EventLogger\n", "from llama_stack_client.lib.agents.event_logger import EventLogger\n",
"from llama_stack_client.types.agent_create_params import (\n", "from llama_stack_client.types.agent_create_params import (\n",
" AgentConfig,\n",
" AgentConfigToolSearchToolDefinition,\n", " AgentConfigToolSearchToolDefinition,\n",
")\n", ")\n",
"\n", "\n",
@ -117,7 +116,8 @@
") -> Agent:\n", ") -> Agent:\n",
" \"\"\"Create an agent with specified tools.\"\"\"\n", " \"\"\"Create an agent with specified tools.\"\"\"\n",
" print(\"Using the following model: \", model)\n", " print(\"Using the following model: \", model)\n",
" agent_config = AgentConfig(\n", " return Agent(\n",
" client, \n",
" model=model,\n", " model=model,\n",
" instructions=instructions,\n", " instructions=instructions,\n",
" sampling_params={\n", " sampling_params={\n",
@ -126,12 +126,7 @@
" },\n", " },\n",
" },\n", " },\n",
" tools=tools,\n", " tools=tools,\n",
" tool_choice=\"auto\",\n", " )\n"
" tool_prompt_format=\"json\",\n",
" enable_session_persistence=True,\n",
" )\n",
"\n",
" return Agent(client, agent_config)\n"
] ]
}, },
{ {
@ -360,9 +355,9 @@
" # Create the agent with the tool\n", " # Create the agent with the tool\n",
" weather_tool = WeatherTool()\n", " weather_tool = WeatherTool()\n",
"\n", "\n",
" agent_config = AgentConfig(\n", " agent = Agent(\n",
" client=client, \n",
" model=LLAMA31_8B_INSTRUCT,\n", " model=LLAMA31_8B_INSTRUCT,\n",
" # model=model_name,\n",
" instructions=\"\"\"\n", " instructions=\"\"\"\n",
" You are a weather assistant that can provide weather information.\n", " You are a weather assistant that can provide weather information.\n",
" Always specify the location clearly in your responses.\n", " Always specify the location clearly in your responses.\n",
@ -373,16 +368,9 @@
" \"type\": \"greedy\",\n", " \"type\": \"greedy\",\n",
" },\n", " },\n",
" },\n", " },\n",
" tools=[weather_tool.get_tool_definition()],\n", " tools=[weather_tool],\n",
" tool_choice=\"auto\",\n",
" tool_prompt_format=\"json\",\n",
" input_shields=[],\n",
" output_shields=[],\n",
" enable_session_persistence=True,\n",
" )\n", " )\n",
"\n", "\n",
" agent = Agent(client=client, agent_config=agent_config, custom_tools=[weather_tool])\n",
"\n",
" return agent\n", " return agent\n",
"\n", "\n",
"\n", "\n",

View file

@ -41,16 +41,36 @@ from llama_stack.schema_utils import json_schema_type, register_schema, webmetho
class Attachment(BaseModel): class Attachment(BaseModel):
"""An attachment to an agent turn.
:param content: The content of the attachment.
:param mime_type: The MIME type of the attachment.
"""
content: InterleavedContent | URL content: InterleavedContent | URL
mime_type: str mime_type: str
class Document(BaseModel): class Document(BaseModel):
"""A document to be used by an agent.
:param content: The content of the document.
:param mime_type: The MIME type of the document.
"""
content: InterleavedContent | URL content: InterleavedContent | URL
mime_type: str mime_type: str
class StepCommon(BaseModel): class StepCommon(BaseModel):
"""A common step in an agent turn.
:param turn_id: The ID of the turn.
:param step_id: The ID of the step.
:param started_at: The time the step started.
:param completed_at: The time the step completed.
"""
turn_id: str turn_id: str
step_id: str step_id: str
started_at: Optional[datetime] = None started_at: Optional[datetime] = None
@ -58,6 +78,14 @@ class StepCommon(BaseModel):
class StepType(Enum): class StepType(Enum):
"""Type of the step in an agent turn.
:cvar inference: The step is an inference step that calls an LLM.
:cvar tool_execution: The step is a tool execution step that executes a tool call.
:cvar shield_call: The step is a shield call step that checks for safety violations.
:cvar memory_retrieval: The step is a memory retrieval step that retrieves context for vector dbs.
"""
inference = "inference" inference = "inference"
tool_execution = "tool_execution" tool_execution = "tool_execution"
shield_call = "shield_call" shield_call = "shield_call"
@ -66,6 +94,11 @@ class StepType(Enum):
@json_schema_type @json_schema_type
class InferenceStep(StepCommon): class InferenceStep(StepCommon):
"""An inference step in an agent turn.
:param model_response: The response from the LLM.
"""
model_config = ConfigDict(protected_namespaces=()) model_config = ConfigDict(protected_namespaces=())
step_type: Literal[StepType.inference.value] = StepType.inference.value step_type: Literal[StepType.inference.value] = StepType.inference.value
@ -74,6 +107,12 @@ class InferenceStep(StepCommon):
@json_schema_type @json_schema_type
class ToolExecutionStep(StepCommon): class ToolExecutionStep(StepCommon):
"""A tool execution step in an agent turn.
:param tool_calls: The tool calls to execute.
:param tool_responses: The tool responses from the tool calls.
"""
step_type: Literal[StepType.tool_execution.value] = StepType.tool_execution.value step_type: Literal[StepType.tool_execution.value] = StepType.tool_execution.value
tool_calls: List[ToolCall] tool_calls: List[ToolCall]
tool_responses: List[ToolResponse] tool_responses: List[ToolResponse]
@ -81,13 +120,25 @@ class ToolExecutionStep(StepCommon):
@json_schema_type @json_schema_type
class ShieldCallStep(StepCommon): class ShieldCallStep(StepCommon):
"""A shield call step in an agent turn.
:param violation: The violation from the shield call.
"""
step_type: Literal[StepType.shield_call.value] = StepType.shield_call.value step_type: Literal[StepType.shield_call.value] = StepType.shield_call.value
violation: Optional[SafetyViolation] violation: Optional[SafetyViolation]
@json_schema_type @json_schema_type
class MemoryRetrievalStep(StepCommon): class MemoryRetrievalStep(StepCommon):
"""A memory retrieval step in an agent turn.
:param vector_db_ids: The IDs of the vector databases to retrieve context from.
:param inserted_context: The context retrieved from the vector databases.
"""
step_type: Literal[StepType.memory_retrieval.value] = StepType.memory_retrieval.value step_type: Literal[StepType.memory_retrieval.value] = StepType.memory_retrieval.value
# TODO: should this be List[str]?
vector_db_ids: str vector_db_ids: str
inserted_context: InterleavedContent inserted_context: InterleavedContent
@ -148,7 +199,7 @@ AgentToolGroup = register_schema(
class AgentConfigCommon(BaseModel): class AgentConfigCommon(BaseModel):
sampling_params: Optional[SamplingParams] = SamplingParams() sampling_params: Optional[SamplingParams] = Field(default_factory=SamplingParams)
input_shields: Optional[List[str]] = Field(default_factory=list) input_shields: Optional[List[str]] = Field(default_factory=list)
output_shields: Optional[List[str]] = Field(default_factory=list) output_shields: Optional[List[str]] = Field(default_factory=list)
@ -296,16 +347,13 @@ class AgentTurnCreateRequest(AgentConfigOverridablePerTurn):
stream: Optional[bool] = False stream: Optional[bool] = False
tool_config: Optional[ToolConfig] = None tool_config: Optional[ToolConfig] = None
# TODO (xiyan): temporary flag, will remove for 0.1.5
allow_turn_resume: Optional[bool] = False
@json_schema_type @json_schema_type
class AgentTurnResumeRequest(BaseModel): class AgentTurnResumeRequest(BaseModel):
agent_id: str agent_id: str
session_id: str session_id: str
turn_id: str turn_id: str
tool_responses: List[ToolResponseMessage] tool_responses: Union[List[ToolResponse], List[ToolResponseMessage]]
stream: Optional[bool] = False stream: Optional[bool] = False
@ -338,7 +386,13 @@ class Agents(Protocol):
async def create_agent( async def create_agent(
self, self,
agent_config: AgentConfig, agent_config: AgentConfig,
) -> AgentCreateResponse: ... ) -> AgentCreateResponse:
"""Create an agent with the given configuration.
:param agent_config: The configuration for the agent.
:returns: An AgentCreateResponse with the agent ID.
"""
...
@webmethod(route="/agents/{agent_id}/session/{session_id}/turn", method="POST") @webmethod(route="/agents/{agent_id}/session/{session_id}/turn", method="POST")
async def create_agent_turn( async def create_agent_turn(
@ -355,8 +409,19 @@ class Agents(Protocol):
documents: Optional[List[Document]] = None, documents: Optional[List[Document]] = None,
toolgroups: Optional[List[AgentToolGroup]] = None, toolgroups: Optional[List[AgentToolGroup]] = None,
tool_config: Optional[ToolConfig] = None, tool_config: Optional[ToolConfig] = None,
allow_turn_resume: Optional[bool] = False, ) -> Union[Turn, AsyncIterator[AgentTurnResponseStreamChunk]]:
) -> Union[Turn, AsyncIterator[AgentTurnResponseStreamChunk]]: ... """Create a new turn for an agent.
:param agent_id: The ID of the agent to create the turn for.
:param session_id: The ID of the session to create the turn for.
:param messages: List of messages to start the turn with.
:param stream: (Optional) If True, generate an SSE event stream of the response. Defaults to False.
:param documents: (Optional) List of documents to create the turn with.
:param toolgroups: (Optional) List of toolgroups to create the turn with, will be used in addition to the agent's config toolgroups for the request.
:param tool_config: (Optional) The tool configuration to create the turn with, will be used to override the agent's tool_config.
:returns: If stream=False, returns a Turn object.
If stream=True, returns an SSE event stream of AgentTurnResponseStreamChunk
"""
@webmethod( @webmethod(
route="/agents/{agent_id}/session/{session_id}/turn/{turn_id}/resume", route="/agents/{agent_id}/session/{session_id}/turn/{turn_id}/resume",
@ -367,7 +432,7 @@ class Agents(Protocol):
agent_id: str, agent_id: str,
session_id: str, session_id: str,
turn_id: str, turn_id: str,
tool_responses: List[ToolResponseMessage], tool_responses: Union[List[ToolResponse], List[ToolResponseMessage]],
stream: Optional[bool] = False, stream: Optional[bool] = False,
) -> Union[Turn, AsyncIterator[AgentTurnResponseStreamChunk]]: ) -> Union[Turn, AsyncIterator[AgentTurnResponseStreamChunk]]:
"""Resume an agent turn with executed tool call responses. """Resume an agent turn with executed tool call responses.
@ -378,6 +443,7 @@ class Agents(Protocol):
:param session_id: The ID of the session to resume. :param session_id: The ID of the session to resume.
:param turn_id: The ID of the turn to resume. :param turn_id: The ID of the turn to resume.
:param tool_responses: The tool call responses to resume the turn with. :param tool_responses: The tool call responses to resume the turn with.
NOTE: ToolResponseMessage will be deprecated. Use ToolResponse.
:param stream: Whether to stream the response. :param stream: Whether to stream the response.
:returns: A Turn object if stream is False, otherwise an AsyncIterator of AgentTurnResponseStreamChunk objects. :returns: A Turn object if stream is False, otherwise an AsyncIterator of AgentTurnResponseStreamChunk objects.
""" """
@ -392,7 +458,15 @@ class Agents(Protocol):
agent_id: str, agent_id: str,
session_id: str, session_id: str,
turn_id: str, turn_id: str,
) -> Turn: ... ) -> Turn:
"""Retrieve an agent turn by its ID.
:param agent_id: The ID of the agent to get the turn for.
:param session_id: The ID of the session to get the turn for.
:param turn_id: The ID of the turn to get.
:returns: A Turn.
"""
...
@webmethod( @webmethod(
route="/agents/{agent_id}/session/{session_id}/turn/{turn_id}/step/{step_id}", route="/agents/{agent_id}/session/{session_id}/turn/{turn_id}/step/{step_id}",
@ -404,14 +478,30 @@ class Agents(Protocol):
session_id: str, session_id: str,
turn_id: str, turn_id: str,
step_id: str, step_id: str,
) -> AgentStepResponse: ... ) -> AgentStepResponse:
"""Retrieve an agent step by its ID.
:param agent_id: The ID of the agent to get the step for.
:param session_id: The ID of the session to get the step for.
:param turn_id: The ID of the turn to get the step for.
:param step_id: The ID of the step to get.
:returns: An AgentStepResponse.
"""
...
@webmethod(route="/agents/{agent_id}/session", method="POST") @webmethod(route="/agents/{agent_id}/session", method="POST")
async def create_agent_session( async def create_agent_session(
self, self,
agent_id: str, agent_id: str,
session_name: str, session_name: str,
) -> AgentSessionCreateResponse: ... ) -> AgentSessionCreateResponse:
"""Create a new session for an agent.
:param agent_id: The ID of the agent to create the session for.
:param session_name: The name of the session to create.
:returns: An AgentSessionCreateResponse.
"""
...
@webmethod(route="/agents/{agent_id}/session/{session_id}", method="GET") @webmethod(route="/agents/{agent_id}/session/{session_id}", method="GET")
async def get_agents_session( async def get_agents_session(
@ -419,17 +509,35 @@ class Agents(Protocol):
session_id: str, session_id: str,
agent_id: str, agent_id: str,
turn_ids: Optional[List[str]] = None, turn_ids: Optional[List[str]] = None,
) -> Session: ... ) -> Session:
"""Retrieve an agent session by its ID.
:param session_id: The ID of the session to get.
:param agent_id: The ID of the agent to get the session for.
:param turn_ids: (Optional) List of turn IDs to filter the session by.
"""
...
@webmethod(route="/agents/{agent_id}/session/{session_id}", method="DELETE") @webmethod(route="/agents/{agent_id}/session/{session_id}", method="DELETE")
async def delete_agents_session( async def delete_agents_session(
self, self,
session_id: str, session_id: str,
agent_id: str, agent_id: str,
) -> None: ... ) -> None:
"""Delete an agent session by its ID.
:param session_id: The ID of the session to delete.
:param agent_id: The ID of the agent to delete the session for.
"""
...
@webmethod(route="/agents/{agent_id}", method="DELETE") @webmethod(route="/agents/{agent_id}", method="DELETE")
async def delete_agent( async def delete_agent(
self, self,
agent_id: str, agent_id: str,
) -> None: ... ) -> None:
"""Delete an agent by its ID.
:param agent_id: The ID of the agent to delete.
"""
...
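Taken together, the routes documented above map onto a short client-side lifecycle. A hedged sketch using the `Agent` helper shown earlier in this document (the model name and port are placeholders):

```python
from llama_stack_client import LlamaStackClient
from llama_stack_client.lib.agents.agent import Agent
from llama_stack_client.lib.agents.event_logger import EventLogger

client = LlamaStackClient(base_url="http://localhost:5001")

# create_agent and create_agent_session happen behind the helper.
agent = Agent(
    client,
    model="meta-llama/Llama-3.2-3B-Instruct",
    instructions="You are a helpful assistant",
)
session_id = agent.create_session("demo-session")

# create_agent_turn, streamed as AgentTurnResponseStreamChunk events.
turn = agent.create_turn(
    messages=[{"role": "user", "content": "Hello!"}],
    session_id=session_id,
)
for log in EventLogger().log(turn):
    log.print()
```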

View file

@ -40,7 +40,7 @@ class BatchInference(Protocol):
self, self,
model: str, model: str,
content_batch: List[InterleavedContent], content_batch: List[InterleavedContent],
sampling_params: Optional[SamplingParams] = SamplingParams(), sampling_params: Optional[SamplingParams] = None,
response_format: Optional[ResponseFormat] = None, response_format: Optional[ResponseFormat] = None,
logprobs: Optional[LogProbConfig] = None, logprobs: Optional[LogProbConfig] = None,
) -> BatchCompletionResponse: ... ) -> BatchCompletionResponse: ...
@ -50,7 +50,7 @@ class BatchInference(Protocol):
self, self,
model: str, model: str,
messages_batch: List[List[Message]], messages_batch: List[List[Message]],
sampling_params: Optional[SamplingParams] = SamplingParams(), sampling_params: Optional[SamplingParams] = None,
# zero-shot tool definitions as input to the model # zero-shot tool definitions as input to the model
tools: Optional[List[ToolDefinition]] = list, tools: Optional[List[ToolDefinition]] = list,
tool_choice: Optional[ToolChoice] = ToolChoice.auto, tool_choice: Optional[ToolChoice] = ToolChoice.auto,

View file

@ -14,6 +14,14 @@ from llama_stack.schema_utils import json_schema_type, webmethod
@json_schema_type @json_schema_type
class PaginatedRowsResult(BaseModel): class PaginatedRowsResult(BaseModel):
"""
A paginated list of rows from a dataset.
:param rows: The rows in the current page.
:param total_count: The total number of rows in the dataset.
:param next_page_token: The token to get the next page of rows.
"""
# the rows obey the DatasetSchema for the given dataset # the rows obey the DatasetSchema for the given dataset
rows: List[Dict[str, Any]] rows: List[Dict[str, Any]]
total_count: int total_count: int
@ -36,7 +44,15 @@ class DatasetIO(Protocol):
rows_in_page: int, rows_in_page: int,
page_token: Optional[str] = None, page_token: Optional[str] = None,
filter_condition: Optional[str] = None, filter_condition: Optional[str] = None,
) -> PaginatedRowsResult: ... ) -> PaginatedRowsResult:
"""Get a paginated list of rows from a dataset.
:param dataset_id: The ID of the dataset to get the rows from.
:param rows_in_page: The number of rows to get per page.
:param page_token: The token to get the next page of rows.
:param filter_condition: (Optional) A condition to filter the rows by.
"""
...
@webmethod(route="/datasetio/rows", method="POST") @webmethod(route="/datasetio/rows", method="POST")
async def append_rows(self, dataset_id: str, rows: List[Dict[str, Any]]) -> None: ... async def append_rows(self, dataset_id: str, rows: List[Dict[str, Any]]) -> None: ...
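A hedged sketch of walking an entire dataset with this pagination contract; the `client.datasetio.get_rows_paginated` binding is assumed to mirror the server route above:

```python
def iter_all_rows(client, dataset_id: str, page_size: int = 100):
    """Yield every row of a dataset, one page at a time."""
    page_token = None
    while True:
        page = client.datasetio.get_rows_paginated(
            dataset_id=dataset_id,
            rows_in_page=page_size,
            page_token=page_token,
        )
        yield from page.rows
        page_token = page.next_page_token
        if not page_token:          # no further pages
            return
```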

View file

@ -5,6 +5,9 @@
# the root directory of this source tree. # the root directory of this source tree.
from enum import Enum from enum import Enum
from typing import Optional
from pydantic import BaseModel
from llama_stack.schema_utils import json_schema_type from llama_stack.schema_utils import json_schema_type
@ -35,3 +38,20 @@ class Api(Enum):
# built-in API # built-in API
inspect = "inspect" inspect = "inspect"
@json_schema_type
class Error(BaseModel):
"""
Error response from the API. Roughly follows RFC 7807.
:param status: HTTP status code
:param title: Error title, a short summary of the error which is invariant for an error type
:param detail: Error detail, a longer human-readable description of the error
:param instance: (Optional) A URL which can be used to retrieve more information about the specific occurrence of the error
"""
status: int
title: str
detail: str
instance: Optional[str] = None
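Because the model roughly follows RFC 7807, serializing it gives the familiar problem-details shape; a quick standalone illustration (the field values are made up):

```python
from typing import Optional

from pydantic import BaseModel


class Error(BaseModel):
    status: int
    title: str
    detail: str
    instance: Optional[str] = None


err = Error(
    status=404,
    title="Model not found",
    detail="No provider has a model registered under this identifier.",
)
print(err.model_dump_json(indent=2))
```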

View file

@ -19,6 +19,13 @@ from llama_stack.schema_utils import json_schema_type, register_schema, webmetho
@json_schema_type @json_schema_type
class ModelCandidate(BaseModel): class ModelCandidate(BaseModel):
"""A model candidate for evaluation.
:param model: The model ID to evaluate.
:param sampling_params: The sampling parameters for the model.
:param system_message: (Optional) The system message providing instructions or context to the model.
"""
type: Literal["model"] = "model" type: Literal["model"] = "model"
model: str model: str
sampling_params: SamplingParams sampling_params: SamplingParams
@ -27,6 +34,11 @@ class ModelCandidate(BaseModel):
@json_schema_type @json_schema_type
class AgentCandidate(BaseModel): class AgentCandidate(BaseModel):
"""An agent candidate for evaluation.
:param config: The configuration for the agent candidate.
"""
type: Literal["agent"] = "agent" type: Literal["agent"] = "agent"
config: AgentConfig config: AgentConfig
@ -39,6 +51,13 @@ EvalCandidate = register_schema(
@json_schema_type @json_schema_type
class BenchmarkConfig(BaseModel): class BenchmarkConfig(BaseModel):
"""A benchmark configuration for evaluation.
:param eval_candidate: The candidate to evaluate.
:param scoring_params: Map between scoring function id and parameters for each scoring function you want to run
:param num_examples: (Optional) The number of examples to evaluate. If not provided, all examples in the dataset will be evaluated
"""
eval_candidate: EvalCandidate eval_candidate: EvalCandidate
scoring_params: Dict[str, ScoringFnParams] = Field( scoring_params: Dict[str, ScoringFnParams] = Field(
description="Map between scoring function id and parameters for each scoring function you want to run", description="Map between scoring function id and parameters for each scoring function you want to run",
@ -53,18 +72,32 @@ class BenchmarkConfig(BaseModel):
@json_schema_type @json_schema_type
class EvaluateResponse(BaseModel): class EvaluateResponse(BaseModel):
"""The response from an evaluation.
:param generations: The generations from the evaluation.
:param scores: The scores from the evaluation.
"""
generations: List[Dict[str, Any]] generations: List[Dict[str, Any]]
# each key in the dict is a scoring function name # each key in the dict is a scoring function name
scores: Dict[str, ScoringResult] scores: Dict[str, ScoringResult]
class Eval(Protocol): class Eval(Protocol):
"""Llama Stack Evaluation API for running evaluations on model and agent candidates."""
@webmethod(route="/eval/benchmarks/{benchmark_id}/jobs", method="POST") @webmethod(route="/eval/benchmarks/{benchmark_id}/jobs", method="POST")
async def run_eval( async def run_eval(
self, self,
benchmark_id: str, benchmark_id: str,
task_config: BenchmarkConfig, benchmark_config: BenchmarkConfig,
) -> Job: ... ) -> Job:
"""Run an evaluation on a benchmark.
:param benchmark_id: The ID of the benchmark to run the evaluation on.
:param benchmark_config: The configuration for the benchmark.
:return: The job that was created to run the evaluation.
"""
@webmethod(route="/eval/benchmarks/{benchmark_id}/evaluations", method="POST") @webmethod(route="/eval/benchmarks/{benchmark_id}/evaluations", method="POST")
async def evaluate_rows( async def evaluate_rows(
@ -72,14 +105,41 @@ class Eval(Protocol):
benchmark_id: str, benchmark_id: str,
input_rows: List[Dict[str, Any]], input_rows: List[Dict[str, Any]],
scoring_functions: List[str], scoring_functions: List[str],
task_config: BenchmarkConfig, benchmark_config: BenchmarkConfig,
) -> EvaluateResponse: ... ) -> EvaluateResponse:
"""Evaluate a list of rows on a benchmark.
:param benchmark_id: The ID of the benchmark to run the evaluation on.
:param input_rows: The rows to evaluate.
:param scoring_functions: The scoring functions to use for the evaluation.
:param benchmark_config: The configuration for the benchmark.
:return: EvaluateResponse object containing generations and scores
"""
@webmethod(route="/eval/benchmarks/{benchmark_id}/jobs/{job_id}", method="GET") @webmethod(route="/eval/benchmarks/{benchmark_id}/jobs/{job_id}", method="GET")
async def job_status(self, benchmark_id: str, job_id: str) -> Optional[JobStatus]: ... async def job_status(self, benchmark_id: str, job_id: str) -> Optional[JobStatus]:
"""Get the status of a job.
:param benchmark_id: The ID of the benchmark to run the evaluation on.
:param job_id: The ID of the job to get the status of.
:return: The status of the evaluation job.
"""
...
@webmethod(route="/eval/benchmarks/{benchmark_id}/jobs/{job_id}", method="DELETE") @webmethod(route="/eval/benchmarks/{benchmark_id}/jobs/{job_id}", method="DELETE")
async def job_cancel(self, benchmark_id: str, job_id: str) -> None: ... async def job_cancel(self, benchmark_id: str, job_id: str) -> None:
"""Cancel a job.
:param benchmark_id: The ID of the benchmark to run the evaluation on.
:param job_id: The ID of the job to cancel.
"""
...
@webmethod(route="/eval/benchmarks/{benchmark_id}/jobs/{job_id}/result", method="GET") @webmethod(route="/eval/benchmarks/{benchmark_id}/jobs/{job_id}/result", method="GET")
async def job_result(self, benchmark_id: str, job_id: str) -> EvaluateResponse: ... async def job_result(self, benchmark_id: str, job_id: str) -> EvaluateResponse:
"""Get the result of a job.
:param benchmark_id: The ID of the benchmark to run the evaluation on.
:param job_id: The ID of the job to get the result of.
:return: The result of the job.
"""

View file

@ -278,14 +278,14 @@ ResponseFormat = register_schema(
class CompletionRequest(BaseModel): class CompletionRequest(BaseModel):
model: str model: str
content: InterleavedContent content: InterleavedContent
sampling_params: Optional[SamplingParams] = SamplingParams() sampling_params: Optional[SamplingParams] = Field(default_factory=SamplingParams)
response_format: Optional[ResponseFormat] = None response_format: Optional[ResponseFormat] = None
stream: Optional[bool] = False stream: Optional[bool] = False
logprobs: Optional[LogProbConfig] = None logprobs: Optional[LogProbConfig] = None
@json_schema_type @json_schema_type
class CompletionResponse(BaseModel): class CompletionResponse(MetricResponseMixin):
"""Response from a completion request. """Response from a completion request.
:param content: The generated completion text :param content: The generated completion text
@ -299,7 +299,7 @@ class CompletionResponse(BaseModel):
@json_schema_type @json_schema_type
class CompletionResponseStreamChunk(BaseModel): class CompletionResponseStreamChunk(MetricResponseMixin):
"""A chunk of a streamed completion response. """A chunk of a streamed completion response.
:param delta: New content generated since last chunk. This can be one or more tokens. :param delta: New content generated since last chunk. This can be one or more tokens.
@ -357,7 +357,7 @@ class ToolConfig(BaseModel):
class ChatCompletionRequest(BaseModel): class ChatCompletionRequest(BaseModel):
model: str model: str
messages: List[Message] messages: List[Message]
sampling_params: Optional[SamplingParams] = SamplingParams() sampling_params: Optional[SamplingParams] = Field(default_factory=SamplingParams)
tools: Optional[List[ToolDefinition]] = Field(default_factory=list) tools: Optional[List[ToolDefinition]] = Field(default_factory=list)
tool_config: Optional[ToolConfig] = Field(default_factory=ToolConfig) tool_config: Optional[ToolConfig] = Field(default_factory=ToolConfig)
@ -368,7 +368,7 @@ class ChatCompletionRequest(BaseModel):
@json_schema_type @json_schema_type
class ChatCompletionResponseStreamChunk(MetricResponseMixin, BaseModel): class ChatCompletionResponseStreamChunk(MetricResponseMixin):
"""A chunk of a streamed chat completion response. """A chunk of a streamed chat completion response.
:param event: The event containing the new content :param event: The event containing the new content
@ -378,7 +378,7 @@ class ChatCompletionResponseStreamChunk(MetricResponseMixin, BaseModel):
@json_schema_type @json_schema_type
class ChatCompletionResponse(MetricResponseMixin, BaseModel): class ChatCompletionResponse(MetricResponseMixin):
"""Response from a chat completion request. """Response from a chat completion request.
:param completion_message: The complete response message :param completion_message: The complete response message
@ -444,7 +444,7 @@ class Inference(Protocol):
self, self,
model_id: str, model_id: str,
content: InterleavedContent, content: InterleavedContent,
sampling_params: Optional[SamplingParams] = SamplingParams(), sampling_params: Optional[SamplingParams] = None,
response_format: Optional[ResponseFormat] = None, response_format: Optional[ResponseFormat] = None,
stream: Optional[bool] = False, stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None, logprobs: Optional[LogProbConfig] = None,
@ -467,7 +467,7 @@ class Inference(Protocol):
self, self,
model_id: str, model_id: str,
messages: List[Message], messages: List[Message],
sampling_params: Optional[SamplingParams] = SamplingParams(), sampling_params: Optional[SamplingParams] = None,
tools: Optional[List[ToolDefinition]] = None, tools: Optional[List[ToolDefinition]] = None,
tool_choice: Optional[ToolChoice] = ToolChoice.auto, tool_choice: Optional[ToolChoice] = ToolChoice.auto,
tool_prompt_format: Optional[ToolPromptFormat] = None, tool_prompt_format: Optional[ToolPromptFormat] = None,
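With `sampling_params` now defaulting to `None`, callers that care about decoding behaviour pass it explicitly; a minimal client-side sketch, with the dict shape following the examples earlier in this document:

```python
from llama_stack_client import LlamaStackClient

client = LlamaStackClient(base_url="http://localhost:5001")

response = client.inference.chat_completion(
    model_id="meta-llama/Llama-3.2-3B-Instruct",
    messages=[{"role": "user", "content": "Write a haiku about vectors."}],
    sampling_params={
        "strategy": {"type": "top_p", "temperature": 0.7, "top_p": 0.9},
        "max_tokens": 256,
    },
)
print(response.completion_message.content)
```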

View file

@ -17,6 +17,13 @@ ScoringResultRow = Dict[str, Any]
@json_schema_type @json_schema_type
class ScoringResult(BaseModel): class ScoringResult(BaseModel):
"""
A scoring result for a single row.
:param score_rows: The scoring result for each row. Each row is a map of column name to value.
:param aggregated_results: Map of metric name to aggregated value
"""
score_rows: List[ScoringResultRow] score_rows: List[ScoringResultRow]
# aggregated metrics to value # aggregated metrics to value
aggregated_results: Dict[str, Any] aggregated_results: Dict[str, Any]
@ -30,6 +37,12 @@ class ScoreBatchResponse(BaseModel):
@json_schema_type @json_schema_type
class ScoreResponse(BaseModel): class ScoreResponse(BaseModel):
"""
The response from scoring.
:param results: A map of scoring function name to ScoringResult.
"""
# each key in the dict is a scoring function name # each key in the dict is a scoring function name
results: Dict[str, ScoringResult] results: Dict[str, ScoringResult]
@ -55,4 +68,11 @@ class Scoring(Protocol):
self, self,
input_rows: List[Dict[str, Any]], input_rows: List[Dict[str, Any]],
scoring_functions: Dict[str, Optional[ScoringFnParams]], scoring_functions: Dict[str, Optional[ScoringFnParams]],
) -> ScoreResponse: ... ) -> ScoreResponse:
"""Score a list of rows.
:param input_rows: The rows to score.
:param scoring_functions: The scoring functions to use for the scoring.
:return: ScoreResponse object containing rows and aggregated results
"""
...
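A hedged example of driving this endpoint from the client; the scoring function id comes from the evaluation guide earlier in this diff, and the column names in the rows are assumptions:

```python
from llama_stack_client import LlamaStackClient

client = LlamaStackClient(base_url="http://localhost:5001")

rows = [
    {
        "input_query": "What is the capital of France?",
        "generated_answer": "Paris is the capital of France.",
        "expected_answer": "Paris",
    }
]

response = client.scoring.score(
    input_rows=rows,
    scoring_functions={"basic::subset_of": None},
)
print(response.results["basic::subset_of"].aggregated_results)
```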

View file

@ -64,7 +64,7 @@ class ModelDescribe(Subcommand):
] ]
if model.recommended_sampling_params is not None: if model.recommended_sampling_params is not None:
sampling_params = model.recommended_sampling_params.dict() sampling_params = model.recommended_sampling_params.model_dump()
for k in ("max_tokens", "repetition_penalty"): for k in ("max_tokens", "repetition_penalty"):
del sampling_params[k] del sampling_params[k]
rows.append( rows.append(

View file

@ -7,10 +7,14 @@
import argparse import argparse
import textwrap import textwrap
from io import StringIO from io import StringIO
from pathlib import Path
from llama_stack.cli.subcommand import Subcommand from llama_stack.cli.subcommand import Subcommand
from llama_stack.cli.table import print_table
from llama_stack.models.llama.datatypes import CoreModelId, ModelFamily, is_multimodal, model_family from llama_stack.models.llama.datatypes import CoreModelId, ModelFamily, is_multimodal, model_family
ROOT_DIR = Path(__file__).parent.parent
class ModelPromptFormat(Subcommand): class ModelPromptFormat(Subcommand):
"""Llama model cli for describe a model prompt format (message formats)""" """Llama model cli for describe a model prompt format (message formats)"""
@ -48,7 +52,26 @@ class ModelPromptFormat(Subcommand):
supported_model_ids = [ supported_model_ids = [
m for m in CoreModelId if model_family(m) in {ModelFamily.llama3_1, ModelFamily.llama3_2} m for m in CoreModelId if model_family(m) in {ModelFamily.llama3_1, ModelFamily.llama3_2}
] ]
model_str = "\n".join([m.value for m in supported_model_ids])
model_list = [m.value for m in supported_model_ids]
model_str = "\n".join(model_list)
if args.list:
headers = ["Model(s)"]
rows = []
for m in model_list:
rows.append(
[
m,
]
)
print_table(
rows,
headers,
separate_rows=True,
)
return
try: try:
model_id = CoreModelId(args.model_name) model_id = CoreModelId(args.model_name)
except ValueError: except ValueError:
@ -57,9 +80,9 @@ class ModelPromptFormat(Subcommand):
if model_id not in supported_model_ids: if model_id not in supported_model_ids:
self.parser.error(f"{model_id} is not a valid Model. Choose one from --\n {model_str}") self.parser.error(f"{model_id} is not a valid Model. Choose one from --\n {model_str}")
llama_3_1_file = importlib.resources.files("llama_models") / "llama3_1/prompt_format.md" llama_3_1_file = ROOT_DIR / "models" / "llama" / "llama3_1" / "prompt_format.md"
llama_3_2_text_file = importlib.resources.files("llama_models") / "llama3_2/text_prompt_format.md" llama_3_2_text_file = ROOT_DIR / "models" / "llama" / "llama3_2" / "text_prompt_format.md"
llama_3_2_vision_file = importlib.resources.files("llama_models") / "llama3_2/vision_prompt_format.md" llama_3_2_vision_file = ROOT_DIR / "models" / "llama" / "llama3_2" / "vision_prompt_format.md"
if model_family(model_id) == ModelFamily.llama3_1: if model_family(model_id) == ModelFamily.llama3_1:
with importlib.resources.as_file(llama_3_1_file) as f: with importlib.resources.as_file(llama_3_1_file) as f:
content = f.open("r").read() content = f.open("r").read()


@ -141,7 +141,7 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
completer=WordCompleter(available_providers), completer=WordCompleter(available_providers),
complete_while_typing=True, complete_while_typing=True,
validator=Validator.from_callable( validator=Validator.from_callable(
lambda x: x in available_providers, lambda x: x in available_providers, # noqa: B023 - see https://github.com/astral-sh/ruff/issues/7847
error_message="Invalid provider, use <TAB> to see options", error_message="Invalid provider, use <TAB> to see options",
), ),
) )
@ -248,7 +248,7 @@ def _generate_run_config(
config_type = instantiate_class_type(provider_registry[Api(api)][provider_type].config_class) config_type = instantiate_class_type(provider_registry[Api(api)][provider_type].config_class)
if hasattr(config_type, "sample_run_config"): if hasattr(config_type, "sample_run_config"):
config = config_type.sample_run_config(__distro_dir__=f"distributions/{image_name}") config = config_type.sample_run_config(__distro_dir__=f"~/.llama/distributions/{image_name}")
else: else:
config = {} config = {}


@ -26,7 +26,7 @@ class StackBuild(Subcommand):
"--config", "--config",
type=str, type=str,
default=None, default=None,
help="Path to a config file to use for the build. You can find example configs in llama_stack/distribution/**/build.yaml. If this argument is not provided, you will be prompted to enter information interactively", help="Path to a config file to use for the build. You can find example configs in llama_stack/distributions/**/build.yaml. If this argument is not provided, you will be prompted to enter information interactively",
) )
self.parser.add_argument( self.parser.add_argument(


@ -37,7 +37,7 @@ class StackRun(Subcommand):
self.parser.add_argument( self.parser.add_argument(
"--port", "--port",
type=int, type=int,
help="Port to run the server on. Defaults to 8321", help="Port to run the server on. It can also be passed via the env var LLAMA_STACK_PORT. Defaults to 8321",
default=int(os.getenv("LLAMA_STACK_PORT", 8321)), default=int(os.getenv("LLAMA_STACK_PORT", 8321)),
) )
self.parser.add_argument( self.parser.add_argument(
@ -79,12 +79,8 @@ class StackRun(Subcommand):
def _run_stack_run_cmd(self, args: argparse.Namespace) -> None: def _run_stack_run_cmd(self, args: argparse.Namespace) -> None:
import yaml import yaml
from llama_stack.distribution.build import ImageType
from llama_stack.distribution.configure import parse_and_maybe_upgrade_config from llama_stack.distribution.configure import parse_and_maybe_upgrade_config
from llama_stack.distribution.utils.config_dirs import ( from llama_stack.distribution.utils.config_dirs import DISTRIBS_BASE_DIR
BUILDS_BASE_DIR,
DISTRIBS_BASE_DIR,
)
from llama_stack.distribution.utils.exec import formulate_run_args, run_with_pty from llama_stack.distribution.utils.exec import formulate_run_args, run_with_pty
config_file = Path(args.config) config_file = Path(args.config)
@ -97,14 +93,6 @@ class StackRun(Subcommand):
if config_file.exists(): if config_file.exists():
template_name = args.config template_name = args.config
if not config_file.exists() and not has_yaml_suffix:
# check if it's a build config saved to conda dir
config_file = Path(BUILDS_BASE_DIR / ImageType.conda.value / f"{args.config}-run.yaml")
if not config_file.exists() and not has_yaml_suffix:
# check if it's a build config saved to container dir
config_file = Path(BUILDS_BASE_DIR / ImageType.container.value / f"{args.config}-run.yaml")
if not config_file.exists() and not has_yaml_suffix: if not config_file.exists() and not has_yaml_suffix:
# check if it's a build config saved to ~/.llama dir # check if it's a build config saved to ~/.llama dir
config_file = Path(DISTRIBS_BASE_DIR / f"llamastack-{args.config}" / f"{args.config}-run.yaml") config_file = Path(DISTRIBS_BASE_DIR / f"llamastack-{args.config}" / f"{args.config}-run.yaml")
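With the conda and container build-dir checks removed above, the lookup order shrinks to: treat the argument as a literal file, then fall back to the per-distribution `run.yaml` under `~/.llama`. A rough sketch of that resolution order, with illustrative paths and names:

```python
# Rough sketch (illustrative, not the actual CLI code) of the simplified lookup.
from pathlib import Path
from typing import Optional

def resolve_run_config(arg: str, distribs_base_dir: Path) -> Optional[Path]:
    candidates = [
        Path(arg),  # the argument itself may already be a run config file
        distribs_base_dir / f"llamastack-{arg}" / f"{arg}-run.yaml",
    ]
    for candidate in candidates:
        if candidate.exists():
            return candidate
    return None  # nothing found on disk

print(resolve_run_config("ollama", Path.home() / ".llama" / "distributions"))
```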


@ -15,7 +15,6 @@ from termcolor import cprint
from llama_stack.distribution.datatypes import BuildConfig, Provider from llama_stack.distribution.datatypes import BuildConfig, Provider
from llama_stack.distribution.distribution import get_provider_registry from llama_stack.distribution.distribution import get_provider_registry
from llama_stack.distribution.utils.config_dirs import BUILDS_BASE_DIR
from llama_stack.distribution.utils.exec import run_command, run_with_pty from llama_stack.distribution.utils.exec import run_command, run_with_pty
from llama_stack.distribution.utils.image_types import ImageType from llama_stack.distribution.utils.image_types import ImageType
from llama_stack.providers.datatypes import Api from llama_stack.providers.datatypes import Api
@ -103,8 +102,6 @@ def build_image(
template_or_config, template_or_config,
image_name, image_name,
container_base, container_base,
str(build_file_path),
str(BUILDS_BASE_DIR / ImageType.container.value),
" ".join(normal_deps), " ".join(normal_deps),
] ]
elif build_config.image_type == ImageType.conda.value: elif build_config.image_type == ImageType.conda.value:


@ -6,8 +6,8 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
LLAMA_MODELS_DIR=${LLAMA_MODELS_DIR:-}
LLAMA_STACK_DIR=${LLAMA_STACK_DIR:-} LLAMA_STACK_DIR=${LLAMA_STACK_DIR:-}
LLAMA_STACK_CLIENT_DIR=${LLAMA_STACK_CLIENT_DIR:-}
TEST_PYPI_VERSION=${TEST_PYPI_VERSION:-} TEST_PYPI_VERSION=${TEST_PYPI_VERSION:-}
# This timeout (in seconds) is necessary when installing PyTorch via uv since it's likely to time out # This timeout (in seconds) is necessary when installing PyTorch via uv since it's likely to time out
# Reference: https://github.com/astral-sh/uv/pull/1694 # Reference: https://github.com/astral-sh/uv/pull/1694
@ -16,8 +16,8 @@ UV_HTTP_TIMEOUT=${UV_HTTP_TIMEOUT:-500}
if [ -n "$LLAMA_STACK_DIR" ]; then if [ -n "$LLAMA_STACK_DIR" ]; then
echo "Using llama-stack-dir=$LLAMA_STACK_DIR" echo "Using llama-stack-dir=$LLAMA_STACK_DIR"
fi fi
if [ -n "$LLAMA_MODELS_DIR" ]; then if [ -n "$LLAMA_STACK_CLIENT_DIR" ]; then
echo "Using llama-models-dir=$LLAMA_MODELS_DIR" echo "Using llama-stack-client-dir=$LLAMA_STACK_CLIENT_DIR"
fi fi
if [ "$#" -lt 3 ]; then if [ "$#" -lt 3 ]; then
@ -52,7 +52,7 @@ ensure_conda_env_python310() {
local python_version="3.10" local python_version="3.10"
# Check if conda command is available # Check if conda command is available
if ! command -v conda &>/dev/null; then if ! is_command_available conda; then
printf "${RED}Error: conda command not found. Is Conda installed and in your PATH?${NC}" >&2 printf "${RED}Error: conda command not found. Is Conda installed and in your PATH?${NC}" >&2
exit 1 exit 1
fi fi
@ -87,8 +87,6 @@ ensure_conda_env_python310() {
# these packages are damaged in test-pypi, so install them first # these packages are damaged in test-pypi, so install them first
uv pip install fastapi libcst uv pip install fastapi libcst
uv pip install --extra-index-url https://test.pypi.org/simple/ \ uv pip install --extra-index-url https://test.pypi.org/simple/ \
llama-models==$TEST_PYPI_VERSION \
llama-stack-client==$TEST_PYPI_VERSION \
llama-stack==$TEST_PYPI_VERSION \ llama-stack==$TEST_PYPI_VERSION \
$pip_dependencies $pip_dependencies
if [ -n "$special_pip_deps" ]; then if [ -n "$special_pip_deps" ]; then
@ -111,22 +109,21 @@ ensure_conda_env_python310() {
else else
PYPI_VERSION="${PYPI_VERSION:-}" PYPI_VERSION="${PYPI_VERSION:-}"
if [ -n "$PYPI_VERSION" ]; then if [ -n "$PYPI_VERSION" ]; then
SPEC_VERSION="llama-stack==${PYPI_VERSION} llama-models==${PYPI_VERSION} llama-stack-client==${PYPI_VERSION}" SPEC_VERSION="llama-stack==${PYPI_VERSION}"
else else
SPEC_VERSION="llama-stack" SPEC_VERSION="llama-stack"
fi fi
uv pip install --no-cache-dir $SPEC_VERSION uv pip install --no-cache-dir $SPEC_VERSION
fi fi
if [ -n "$LLAMA_MODELS_DIR" ]; then if [ -n "$LLAMA_STACK_CLIENT_DIR" ]; then
if [ ! -d "$LLAMA_MODELS_DIR" ]; then if [ ! -d "$LLAMA_STACK_CLIENT_DIR" ]; then
printf "${RED}Warning: LLAMA_MODELS_DIR is set but directory does not exist: $LLAMA_MODELS_DIR${NC}\n" >&2 printf "${RED}Warning: LLAMA_STACK_CLIENT_DIR is set but directory does not exist: $LLAMA_STACK_CLIENT_DIR${NC}\n" >&2
exit 1 exit 1
fi fi
printf "Installing from LLAMA_MODELS_DIR: $LLAMA_MODELS_DIR\n" printf "Installing from LLAMA_STACK_CLIENT_DIR: $LLAMA_STACK_CLIENT_DIR\n"
uv pip uninstall llama-models uv pip install --no-cache-dir -e "$LLAMA_STACK_CLIENT_DIR"
uv pip install --no-cache-dir -e "$LLAMA_MODELS_DIR"
fi fi
# Install pip dependencies # Install pip dependencies

View file

@ -1,4 +1,4 @@
#!/bin/bash #!/usr/bin/env bash
# Copyright (c) Meta Platforms, Inc. and affiliates. # Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved. # All rights reserved.
@ -6,7 +6,6 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
LLAMA_MODELS_DIR=${LLAMA_MODELS_DIR:-}
LLAMA_STACK_DIR=${LLAMA_STACK_DIR:-} LLAMA_STACK_DIR=${LLAMA_STACK_DIR:-}
LLAMA_STACK_CLIENT_DIR=${LLAMA_STACK_CLIENT_DIR:-} LLAMA_STACK_CLIENT_DIR=${LLAMA_STACK_CLIENT_DIR:-}
@ -20,26 +19,27 @@ UV_HTTP_TIMEOUT=${UV_HTTP_TIMEOUT:-500}
# mounting is not supported by docker buildx, so we use COPY instead # mounting is not supported by docker buildx, so we use COPY instead
USE_COPY_NOT_MOUNT=${USE_COPY_NOT_MOUNT:-} USE_COPY_NOT_MOUNT=${USE_COPY_NOT_MOUNT:-}
if [ "$#" -lt 6 ]; then if [ "$#" -lt 4 ]; then
# This only works for templates # This only works for templates
echo "Usage: $0 <template_or_config> <image_name> <container_base> <build_file_path> <host_build_dir> <pip_dependencies> [<special_pip_deps>]" >&2 echo "Usage: $0 <template_or_config> <image_name> <container_base> <pip_dependencies> [<special_pip_deps>]" >&2
exit 1 exit 1
fi fi
set -euo pipefail set -euo pipefail
template_or_config="$1" template_or_config="$1"
image_name="$2" shift
container_base="$3" image_name="$1"
build_file_path="$4" shift
host_build_dir="$5" container_base="$1"
pip_dependencies="$6" shift
special_pip_deps="${7:-}" pip_dependencies="$1"
shift
special_pip_deps="${1:-}"
# Define color codes # Define color codes
RED='\033[0;31m' RED='\033[0;31m'
GREEN='\033[0;32m'
NC='\033[0m' # No Color NC='\033[0m' # No Color
CONTAINER_BINARY=${CONTAINER_BINARY:-docker} CONTAINER_BINARY=${CONTAINER_BINARY:-docker}
@ -47,8 +47,10 @@ CONTAINER_OPTS=${CONTAINER_OPTS:-}
TEMP_DIR=$(mktemp -d) TEMP_DIR=$(mktemp -d)
SCRIPT_DIR=$(dirname "$(readlink -f "$0")")
source "$SCRIPT_DIR/common.sh"
add_to_container() { add_to_container() {
local input
output_file="$TEMP_DIR/Containerfile" output_file="$TEMP_DIR/Containerfile"
if [ -t 0 ]; then if [ -t 0 ]; then
printf '%s\n' "$1" >>"$output_file" printf '%s\n' "$1" >>"$output_file"
@ -58,15 +60,21 @@ add_to_container() {
fi fi
} }
# Check if container command is available
if ! is_command_available $CONTAINER_BINARY; then
printf "${RED}Error: ${CONTAINER_BINARY} command not found. Is ${CONTAINER_BINARY} installed and in your PATH?${NC}" >&2
exit 1
fi
# Update and install UBI9 components if UBI9 base image is used # Update and install UBI9 components if UBI9 base image is used
if [[ $container_base == *"registry.access.redhat.com/ubi9"* ]]; then if [[ $container_base == *"registry.access.redhat.com/ubi9"* ]]; then
add_to_container << EOF add_to_container << EOF
FROM $container_base FROM $container_base
WORKDIR /app WORKDIR /app
RUN microdnf -y update && microdnf install -y iputils net-tools wget \ RUN dnf -y update && dnf install -y iputils net-tools wget \
vim-minimal python3.11 python3.11-pip python3.11-wheel \ vim-minimal python3.11 python3.11-pip python3.11-wheel \
python3.11-setuptools && ln -s /bin/pip3.11 /bin/pip && ln -s /bin/python3.11 /bin/python && microdnf clean all python3.11-setuptools && ln -s /bin/pip3.11 /bin/pip && ln -s /bin/python3.11 /bin/python && dnf clean all
ENV UV_SYSTEM_PYTHON=1 ENV UV_SYSTEM_PYTHON=1
RUN pip install uv RUN pip install uv
@ -107,7 +115,6 @@ EOF
fi fi
stack_mount="/app/llama-stack-source" stack_mount="/app/llama-stack-source"
models_mount="/app/llama-models-source"
client_mount="/app/llama-stack-client-source" client_mount="/app/llama-stack-client-source"
install_local_package() { install_local_package() {
@ -131,10 +138,6 @@ EOF
} }
if [ -n "$LLAMA_MODELS_DIR" ]; then
install_local_package "$LLAMA_MODELS_DIR" "$models_mount" "LLAMA_MODELS_DIR"
fi
if [ -n "$LLAMA_STACK_CLIENT_DIR" ]; then if [ -n "$LLAMA_STACK_CLIENT_DIR" ]; then
install_local_package "$LLAMA_STACK_CLIENT_DIR" "$client_mount" "LLAMA_STACK_CLIENT_DIR" install_local_package "$LLAMA_STACK_CLIENT_DIR" "$client_mount" "LLAMA_STACK_CLIENT_DIR"
fi fi
@ -150,12 +153,12 @@ EOF
add_to_container << EOF add_to_container << EOF
RUN uv pip install --no-cache --extra-index-url https://test.pypi.org/simple/ \ RUN uv pip install --no-cache --extra-index-url https://test.pypi.org/simple/ \
--index-strategy unsafe-best-match \ --index-strategy unsafe-best-match \
llama-models==$TEST_PYPI_VERSION llama-stack-client==$TEST_PYPI_VERSION llama-stack==$TEST_PYPI_VERSION llama-stack==$TEST_PYPI_VERSION
EOF EOF
else else
if [ -n "$PYPI_VERSION" ]; then if [ -n "$PYPI_VERSION" ]; then
SPEC_VERSION="llama-stack==${PYPI_VERSION} llama-models==${PYPI_VERSION} llama-stack-client==${PYPI_VERSION}" SPEC_VERSION="llama-stack==${PYPI_VERSION}"
else else
SPEC_VERSION="llama-stack" SPEC_VERSION="llama-stack"
fi fi
@ -165,6 +168,11 @@ EOF
fi fi
fi fi
# remove uv after installation
add_to_container << EOF
RUN pip uninstall -y uv
EOF
# if template_or_config ends with .yaml, it is not a template and we should not use the --template flag # if template_or_config ends with .yaml, it is not a template and we should not use the --template flag
if [[ "$template_or_config" != *.yaml ]]; then if [[ "$template_or_config" != *.yaml ]]; then
add_to_container << EOF add_to_container << EOF
@ -185,26 +193,28 @@ RUN mkdir -p /.llama /.cache
RUN chmod -R g+rw /app /.llama /.cache RUN chmod -R g+rw /app /.llama /.cache
EOF EOF
printf "Containerfile created successfully in $TEMP_DIR/Containerfile\n\n" printf "Containerfile created successfully in %s/Containerfile\n\n" "$TEMP_DIR"
cat $TEMP_DIR/Containerfile cat "$TEMP_DIR"/Containerfile
printf "\n" printf "\n"
mounts="" # Start building the CLI arguments
CLI_ARGS=()
# Read CONTAINER_OPTS and put it in an array
read -ra CLI_ARGS <<< "$CONTAINER_OPTS"
if [ "$USE_COPY_NOT_MOUNT" != "true" ]; then if [ "$USE_COPY_NOT_MOUNT" != "true" ]; then
if [ -n "$LLAMA_STACK_DIR" ]; then if [ -n "$LLAMA_STACK_DIR" ]; then
mounts="$mounts -v $(readlink -f $LLAMA_STACK_DIR):$stack_mount" CLI_ARGS+=("-v" "$(readlink -f "$LLAMA_STACK_DIR"):$stack_mount")
fi
if [ -n "$LLAMA_MODELS_DIR" ]; then
mounts="$mounts -v $(readlink -f $LLAMA_MODELS_DIR):$models_mount"
fi fi
if [ -n "$LLAMA_STACK_CLIENT_DIR" ]; then if [ -n "$LLAMA_STACK_CLIENT_DIR" ]; then
mounts="$mounts -v $(readlink -f $LLAMA_STACK_CLIENT_DIR):$client_mount" CLI_ARGS+=("-v" "$(readlink -f "$LLAMA_STACK_CLIENT_DIR"):$client_mount")
fi fi
fi fi
if command -v selinuxenabled &>/dev/null && selinuxenabled; then if is_command_available selinuxenabled && selinuxenabled; then
# Disable SELinux labels -- we don't want to relabel the llama-stack source dir # Disable SELinux labels -- we don't want to relabel the llama-stack source dir
CONTAINER_OPTS="$CONTAINER_OPTS --security-opt label=disable" CLI_ARGS+=("--security-opt" "label=disable")
fi fi
# Set version tag based on PyPI version # Set version tag based on PyPI version
@ -212,7 +222,7 @@ if [ -n "$PYPI_VERSION" ]; then
version_tag="$PYPI_VERSION" version_tag="$PYPI_VERSION"
elif [ -n "$TEST_PYPI_VERSION" ]; then elif [ -n "$TEST_PYPI_VERSION" ]; then
version_tag="test-$TEST_PYPI_VERSION" version_tag="test-$TEST_PYPI_VERSION"
elif [[ -n "$LLAMA_STACK_DIR" || -n "$LLAMA_MODELS_DIR" ]]; then elif [[ -n "$LLAMA_STACK_DIR" || -n "$LLAMA_STACK_CLIENT_DIR" ]]; then
version_tag="dev" version_tag="dev"
else else
URL="https://pypi.org/pypi/llama-stack/json" URL="https://pypi.org/pypi/llama-stack/json"
@ -225,11 +235,11 @@ image_tag="$image_name:$version_tag"
# Detect platform architecture # Detect platform architecture
ARCH=$(uname -m) ARCH=$(uname -m)
if [ -n "$BUILD_PLATFORM" ]; then if [ -n "$BUILD_PLATFORM" ]; then
PLATFORM="--platform $BUILD_PLATFORM" CLI_ARGS+=("--platform $BUILD_PLATFORM")
elif [ "$ARCH" = "arm64" ] || [ "$ARCH" = "aarch64" ]; then elif [ "$ARCH" = "arm64" ] || [ "$ARCH" = "aarch64" ]; then
PLATFORM="--platform linux/arm64" CLI_ARGS+=("--platform" "linux/arm64")
elif [ "$ARCH" = "x86_64" ]; then elif [ "$ARCH" = "x86_64" ]; then
PLATFORM="--platform linux/amd64" CLI_ARGS+=("--platform" "linux/amd64")
else else
echo "Unsupported architecture: $ARCH" echo "Unsupported architecture: $ARCH"
exit 1 exit 1
@ -238,8 +248,13 @@ fi
echo "PWD: $(pwd)" echo "PWD: $(pwd)"
echo "Containerfile: $TEMP_DIR/Containerfile" echo "Containerfile: $TEMP_DIR/Containerfile"
set -x set -x
$CONTAINER_BINARY build $CONTAINER_OPTS $PLATFORM -t $image_tag \
-f "$TEMP_DIR/Containerfile" "." $mounts --progress=plain $CONTAINER_BINARY build \
"${CLI_ARGS[@]}" \
-t "$image_tag" \
-f "$TEMP_DIR/Containerfile" \
"." \
--progress=plain
# clean up tmp/configs # clean up tmp/configs
set +x set +x


@ -9,8 +9,8 @@
# TODO: combine this with build_conda_env.sh since it is almost identical # TODO: combine this with build_conda_env.sh since it is almost identical
# the only difference is that we don't do any conda-specific setup # the only difference is that we don't do any conda-specific setup
LLAMA_MODELS_DIR=${LLAMA_MODELS_DIR:-}
LLAMA_STACK_DIR=${LLAMA_STACK_DIR:-} LLAMA_STACK_DIR=${LLAMA_STACK_DIR:-}
LLAMA_STACK_CLIENT_DIR=${LLAMA_STACK_CLIENT_DIR:-}
TEST_PYPI_VERSION=${TEST_PYPI_VERSION:-} TEST_PYPI_VERSION=${TEST_PYPI_VERSION:-}
# This timeout (in seconds) is necessary when installing PyTorch via uv since it's likely to time out # This timeout (in seconds) is necessary when installing PyTorch via uv since it's likely to time out
# Reference: https://github.com/astral-sh/uv/pull/1694 # Reference: https://github.com/astral-sh/uv/pull/1694
@ -21,8 +21,8 @@ VIRTUAL_ENV=${VIRTUAL_ENV:-}
if [ -n "$LLAMA_STACK_DIR" ]; then if [ -n "$LLAMA_STACK_DIR" ]; then
echo "Using llama-stack-dir=$LLAMA_STACK_DIR" echo "Using llama-stack-dir=$LLAMA_STACK_DIR"
fi fi
if [ -n "$LLAMA_MODELS_DIR" ]; then if [ -n "$LLAMA_STACK_CLIENT_DIR" ]; then
echo "Using llama-models-dir=$LLAMA_MODELS_DIR" echo "Using llama-stack-client-dir=$LLAMA_STACK_CLIENT_DIR"
fi fi
if [ "$#" -lt 2 ]; then if [ "$#" -lt 2 ]; then
@ -95,7 +95,7 @@ run() {
# we are building a command line so word splitting is expected # we are building a command line so word splitting is expected
uv pip install --extra-index-url https://test.pypi.org/simple/ \ uv pip install --extra-index-url https://test.pypi.org/simple/ \
--index-strategy unsafe-best-match \ --index-strategy unsafe-best-match \
llama-models=="$TEST_PYPI_VERSION" llama-stack=="$TEST_PYPI_VERSION" \ llama-stack=="$TEST_PYPI_VERSION" \
$pip_dependencies $pip_dependencies
if [ -n "$special_pip_deps" ]; then if [ -n "$special_pip_deps" ]; then
IFS='#' read -ra parts <<<"$special_pip_deps" IFS='#' read -ra parts <<<"$special_pip_deps"
@ -120,15 +120,14 @@ run() {
uv pip install --no-cache-dir llama-stack uv pip install --no-cache-dir llama-stack
fi fi
if [ -n "$LLAMA_MODELS_DIR" ]; then if [ -n "$LLAMA_STACK_CLIENT_DIR" ]; then
if [ ! -d "$LLAMA_MODELS_DIR" ]; then if [ ! -d "$LLAMA_STACK_CLIENT_DIR" ]; then
printf "${RED}Warning: LLAMA_MODELS_DIR is set but directory does not exist: %s${NC}\n" "$LLAMA_MODELS_DIR" >&2 printf "${RED}Warning: LLAMA_STACK_CLIENT_DIR is set but directory does not exist: %s${NC}\n" "$LLAMA_STACK_CLIENT_DIR" >&2
exit 1 exit 1
fi fi
printf "Installing from LLAMA_MODELS_DIR: %s\n" "$LLAMA_MODELS_DIR" printf "Installing from LLAMA_STACK_CLIENT_DIR: %s\n" "$LLAMA_STACK_CLIENT_DIR"
uv pip uninstall llama-models uv pip install --no-cache-dir -e "$LLAMA_STACK_CLIENT_DIR"
uv pip install --no-cache-dir -e "$LLAMA_MODELS_DIR"
fi fi
# Install pip dependencies # Install pip dependencies

View file

@ -39,7 +39,7 @@ def configure_single_provider(registry: Dict[str, ProviderSpec], provider: Provi
return Provider( return Provider(
provider_id=provider.provider_id, provider_id=provider.provider_id,
provider_type=provider.provider_type, provider_type=provider.provider_type,
config=cfg.dict(), config=cfg.model_dump(),
) )


@ -1,47 +0,0 @@
#!/bin/bash
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
CONTAINER_BINARY=${CONTAINER_BINARY:-docker}
CONTAINER_OPTS=${CONTAINER_OPTS:-}
LLAMA_STACK_DIR=${LLAMA_STACK_DIR:-}
set -euo pipefail
error_handler() {
echo "Error occurred in script at line: ${1}" >&2
exit 1
}
trap 'error_handler ${LINENO}' ERR
if [ $# -lt 2 ]; then
echo "Usage: $0 <container name> <build file path>"
exit 1
fi
container_image="$1"
host_build_dir="$2"
container_build_dir="/app/builds"
if command -v selinuxenabled &> /dev/null && selinuxenabled; then
# Disable SELinux labels
CONTAINER_OPTS="$CONTAINER_OPTS --security-opt label=disable"
fi
mounts=""
if [ -n "$LLAMA_STACK_DIR" ]; then
mounts="$mounts -v $(readlink -f $LLAMA_STACK_DIR):/app/llama-stack-source"
fi
set -x
$CONTAINER_BINARY run $CONTAINER_OPTS -it \
--entrypoint "/usr/local/bin/llama" \
-v $host_build_dir:$container_build_dir \
$mounts \
$container_image \
stack configure ./llamastack-build.yaml --output-dir $container_build_dir


@ -13,7 +13,7 @@ from llama_stack.providers.datatypes import Api, ProviderSpec
def stack_apis() -> List[Api]: def stack_apis() -> List[Api]:
return [v for v in Api] return list(Api)
class AutoRoutedApiInfo(BaseModel): class AutoRoutedApiInfo(BaseModel):
@ -59,7 +59,7 @@ def builtin_automatically_routed_apis() -> List[AutoRoutedApiInfo]:
def providable_apis() -> List[Api]: def providable_apis() -> List[Api]:
routing_table_apis = set(x.routing_table_api for x in builtin_automatically_routed_apis()) routing_table_apis = {x.routing_table_api for x in builtin_automatically_routed_apis()}
return [api for api in Api if api not in routing_table_apis and api != Api.inspect] return [api for api in Api if api not in routing_table_apis and api != Api.inspect]


@ -104,7 +104,7 @@ def convert_to_pydantic(annotation: Any, value: Any) -> Any:
logger.warning( logger.warning(
f"Warning: direct client failed to convert parameter {value} into {annotation}: {e}", f"Warning: direct client failed to convert parameter {value} into {annotation}: {e}",
) )
return value raise ValueError(f"Failed to convert parameter {value} into {annotation}: {e}") from e
class LlamaStackAsLibraryClient(LlamaStackClient): class LlamaStackAsLibraryClient(LlamaStackClient):


@ -5,9 +5,9 @@
# the root directory of this source tree. # the root directory of this source tree.
import importlib import importlib
import inspect import inspect
import logging from typing import Any, Dict, List, Set, Tuple
from typing import Any, Dict, List, Set
from llama_stack import logcat
from llama_stack.apis.agents import Agents from llama_stack.apis.agents import Agents
from llama_stack.apis.benchmarks import Benchmarks from llama_stack.apis.benchmarks import Benchmarks
from llama_stack.apis.datasetio import DatasetIO from llama_stack.apis.datasetio import DatasetIO
@ -53,8 +53,6 @@ from llama_stack.providers.datatypes import (
VectorDBsProtocolPrivate, VectorDBsProtocolPrivate,
) )
log = logging.getLogger(__name__)
class InvalidProviderError(Exception): class InvalidProviderError(Exception):
pass pass
@ -110,60 +108,43 @@ class ProviderWithSpec(Provider):
ProviderRegistry = Dict[Api, Dict[str, ProviderSpec]] ProviderRegistry = Dict[Api, Dict[str, ProviderSpec]]
# TODO: this code is not very straightforward to follow and needs one more round of refactoring
async def resolve_impls( async def resolve_impls(
run_config: StackRunConfig, run_config: StackRunConfig,
provider_registry: ProviderRegistry, provider_registry: ProviderRegistry,
dist_registry: DistributionRegistry, dist_registry: DistributionRegistry,
) -> Dict[Api, Any]: ) -> Dict[Api, Any]:
""" """
Does two things: Resolves provider implementations by:
- flatmaps, sorts and resolves the providers in dependency order 1. Validating and organizing providers.
- for each API, produces either a (local, passthrough or router) implementation 2. Sorting them in dependency order.
3. Instantiating them with required dependencies.
""" """
routing_table_apis = set(x.routing_table_api for x in builtin_automatically_routed_apis()) routing_table_apis = {x.routing_table_api for x in builtin_automatically_routed_apis()}
router_apis = set(x.router_api for x in builtin_automatically_routed_apis()) router_apis = {x.router_api for x in builtin_automatically_routed_apis()}
providers_with_specs = {} providers_with_specs = validate_and_prepare_providers(
run_config, provider_registry, routing_table_apis, router_apis
for api_str, providers in run_config.providers.items(): )
api = Api(api_str)
if api in routing_table_apis:
raise ValueError(f"Provider for `{api_str}` is automatically provided and cannot be overridden")
specs = {}
for provider in providers:
if provider.provider_type not in provider_registry[api]:
raise ValueError(f"Provider `{provider.provider_type}` is not available for API `{api}`")
p = provider_registry[api][provider.provider_type]
if p.deprecation_error:
log.error(p.deprecation_error, "red", attrs=["bold"])
raise InvalidProviderError(p.deprecation_error)
elif p.deprecation_warning:
log.warning(
f"Provider `{provider.provider_type}` for API `{api}` is deprecated and will be removed in a future release: {p.deprecation_warning}",
)
p.deps__ = [a.value for a in p.api_dependencies] + [a.value for a in p.optional_api_dependencies]
spec = ProviderWithSpec(
spec=p,
**(provider.model_dump()),
)
specs[provider.provider_id] = spec
key = api_str if api not in router_apis else f"inner-{api_str}"
providers_with_specs[key] = specs
apis_to_serve = run_config.apis or set( apis_to_serve = run_config.apis or set(
list(providers_with_specs.keys()) + [x.value for x in routing_table_apis] + [x.value for x in router_apis] list(providers_with_specs.keys()) + [x.value for x in routing_table_apis] + [x.value for x in router_apis]
) )
providers_with_specs.update(specs_for_autorouted_apis(apis_to_serve))
sorted_providers = sort_providers_by_deps(providers_with_specs, run_config)
return await instantiate_providers(sorted_providers, router_apis, dist_registry)
def specs_for_autorouted_apis(apis_to_serve: List[str] | Set[str]) -> Dict[str, Dict[str, ProviderWithSpec]]:
"""Generates specifications for automatically routed APIs."""
specs = {}
for info in builtin_automatically_routed_apis(): for info in builtin_automatically_routed_apis():
if info.router_api.value not in apis_to_serve: if info.router_api.value not in apis_to_serve:
continue continue
providers_with_specs[info.routing_table_api.value] = { specs[info.routing_table_api.value] = {
"__builtin__": ProviderWithSpec( "__builtin__": ProviderWithSpec(
provider_id="__routing_table__", provider_id="__routing_table__",
provider_type="__routing_table__", provider_type="__routing_table__",
@ -173,12 +154,12 @@ async def resolve_impls(
router_api=info.router_api, router_api=info.router_api,
module="llama_stack.distribution.routers", module="llama_stack.distribution.routers",
api_dependencies=[], api_dependencies=[],
deps__=([f"inner-{info.router_api.value}"]), deps__=[f"inner-{info.router_api.value}"],
), ),
) )
} }
providers_with_specs[info.router_api.value] = { specs[info.router_api.value] = {
"__builtin__": ProviderWithSpec( "__builtin__": ProviderWithSpec(
provider_id="__autorouted__", provider_id="__autorouted__",
provider_type="__autorouted__", provider_type="__autorouted__",
@ -188,12 +169,69 @@ async def resolve_impls(
module="llama_stack.distribution.routers", module="llama_stack.distribution.routers",
routing_table_api=info.routing_table_api, routing_table_api=info.routing_table_api,
api_dependencies=[info.routing_table_api], api_dependencies=[info.routing_table_api],
deps__=([info.routing_table_api.value]), # Add telemetry as an optional dependency to all auto-routed providers
optional_api_dependencies=[Api.telemetry],
deps__=([info.routing_table_api.value, Api.telemetry.value]),
), ),
) )
} }
return specs
sorted_providers = topological_sort({k: v.values() for k, v in providers_with_specs.items()})
def validate_and_prepare_providers(
run_config: StackRunConfig, provider_registry: ProviderRegistry, routing_table_apis: Set[Api], router_apis: Set[Api]
) -> Dict[str, Dict[str, ProviderWithSpec]]:
"""Validates providers, handles deprecations, and organizes them into a spec dictionary."""
providers_with_specs: Dict[str, Dict[str, ProviderWithSpec]] = {}
for api_str, providers in run_config.providers.items():
api = Api(api_str)
if api in routing_table_apis:
raise ValueError(f"Provider for `{api_str}` is automatically provided and cannot be overridden")
specs = {}
for provider in providers:
if not provider.provider_id or provider.provider_id == "__disabled__":
logcat.warning("core", f"Provider `{provider.provider_type}` for API `{api}` is disabled")
continue
validate_provider(provider, api, provider_registry)
p = provider_registry[api][provider.provider_type]
p.deps__ = [a.value for a in p.api_dependencies] + [a.value for a in p.optional_api_dependencies]
spec = ProviderWithSpec(spec=p, **provider.model_dump())
specs[provider.provider_id] = spec
key = api_str if api not in router_apis else f"inner-{api_str}"
providers_with_specs[key] = specs
return providers_with_specs
def validate_provider(provider: Provider, api: Api, provider_registry: ProviderRegistry):
"""Validates if the provider is allowed and handles deprecations."""
if provider.provider_type not in provider_registry[api]:
raise ValueError(f"Provider `{provider.provider_type}` is not available for API `{api}`")
p = provider_registry[api][provider.provider_type]
if p.deprecation_error:
logcat.error("core", p.deprecation_error)
raise InvalidProviderError(p.deprecation_error)
elif p.deprecation_warning:
logcat.warning(
"core",
f"Provider `{provider.provider_type}` for API `{api}` is deprecated and will be removed in a future release: {p.deprecation_warning}",
)
def sort_providers_by_deps(
providers_with_specs: Dict[str, Dict[str, ProviderWithSpec]], run_config: StackRunConfig
) -> List[Tuple[str, ProviderWithSpec]]:
"""Sorts providers based on their dependencies."""
sorted_providers: List[Tuple[str, ProviderWithSpec]] = topological_sort(
{k: list(v.values()) for k, v in providers_with_specs.items()}
)
# Append built-in "inspect" provider
apis = [x[1].spec.api for x in sorted_providers] apis = [x[1].spec.api for x in sorted_providers]
sorted_providers.append( sorted_providers.append(
( (
@ -201,28 +239,31 @@ async def resolve_impls(
ProviderWithSpec( ProviderWithSpec(
provider_id="__builtin__", provider_id="__builtin__",
provider_type="__builtin__", provider_type="__builtin__",
config={ config={"run_config": run_config.model_dump()},
"run_config": run_config.dict(),
},
spec=InlineProviderSpec( spec=InlineProviderSpec(
api=Api.inspect, api=Api.inspect,
provider_type="__builtin__", provider_type="__builtin__",
config_class="llama_stack.distribution.inspect.DistributionInspectConfig", config_class="llama_stack.distribution.inspect.DistributionInspectConfig",
module="llama_stack.distribution.inspect", module="llama_stack.distribution.inspect",
api_dependencies=apis, api_dependencies=apis,
deps__=([x.value for x in apis]), deps__=[x.value for x in apis],
), ),
), ),
) )
) )
log.info(f"Resolved {len(sorted_providers)} providers") logcat.debug("core", f"Resolved {len(sorted_providers)} providers")
for api_str, provider in sorted_providers: for api_str, provider in sorted_providers:
log.info(f" {api_str} => {provider.provider_id}") logcat.debug("core", f" {api_str} => {provider.provider_id}")
log.info("") return sorted_providers
impls = {}
inner_impls_by_provider_id = {f"inner-{x.value}": {} for x in router_apis} async def instantiate_providers(
sorted_providers: List[Tuple[str, ProviderWithSpec]], router_apis: Set[Api], dist_registry: DistributionRegistry
) -> Dict:
"""Instantiates providers asynchronously while managing dependencies."""
impls: Dict[Api, Any] = {}
inner_impls_by_provider_id: Dict[str, Dict[str, Any]] = {f"inner-{x.value}": {} for x in router_apis}
for api_str, provider in sorted_providers: for api_str, provider in sorted_providers:
deps = {a: impls[a] for a in provider.spec.api_dependencies} deps = {a: impls[a] for a in provider.spec.api_dependencies}
for a in provider.spec.optional_api_dependencies: for a in provider.spec.optional_api_dependencies:
@ -233,14 +274,9 @@ async def resolve_impls(
if isinstance(provider.spec, RoutingTableProviderSpec): if isinstance(provider.spec, RoutingTableProviderSpec):
inner_impls = inner_impls_by_provider_id[f"inner-{provider.spec.router_api.value}"] inner_impls = inner_impls_by_provider_id[f"inner-{provider.spec.router_api.value}"]
impl = await instantiate_provider( impl = await instantiate_provider(provider, deps, inner_impls, dist_registry)
provider,
deps, if api_str.startswith("inner-"):
inner_impls,
dist_registry,
)
# TODO: ugh slightly redesign this shady looking code
if "inner-" in api_str:
inner_impls_by_provider_id[api_str][provider.provider_id] = impl inner_impls_by_provider_id[api_str][provider.provider_id] = impl
else: else:
api = Api(api_str) api = Api(api_str)
@ -251,7 +287,7 @@ async def resolve_impls(
def topological_sort( def topological_sort(
providers_with_specs: Dict[str, List[ProviderWithSpec]], providers_with_specs: Dict[str, List[ProviderWithSpec]],
) -> List[ProviderWithSpec]: ) -> List[Tuple[str, ProviderWithSpec]]:
def dfs(kv, visited: Set[str], stack: List[str]): def dfs(kv, visited: Set[str], stack: List[str]):
api_str, providers = kv api_str, providers = kv
visited.add(api_str) visited.add(api_str)
@ -267,8 +303,8 @@ def topological_sort(
stack.append(api_str) stack.append(api_str)
visited = set() visited: Set[str] = set()
stack = [] stack: List[str] = []
for api_str, providers in providers_with_specs.items(): for api_str, providers in providers_with_specs.items():
if api_str not in visited: if api_str not in visited:
@ -278,13 +314,14 @@ def topological_sort(
for api_str in stack: for api_str in stack:
for provider in providers_with_specs[api_str]: for provider in providers_with_specs[api_str]:
flattened.append((api_str, provider)) flattened.append((api_str, provider))
return flattened return flattened
# returns a class implementing the protocol corresponding to the Api # returns a class implementing the protocol corresponding to the Api
async def instantiate_provider( async def instantiate_provider(
provider: ProviderWithSpec, provider: ProviderWithSpec,
deps: Dict[str, Any], deps: Dict[Api, Any],
inner_impls: Dict[str, Any], inner_impls: Dict[str, Any],
dist_registry: DistributionRegistry, dist_registry: DistributionRegistry,
): ):
@ -292,8 +329,10 @@ async def instantiate_provider(
additional_protocols = additional_protocols_map() additional_protocols = additional_protocols_map()
provider_spec = provider.spec provider_spec = provider.spec
module = importlib.import_module(provider_spec.module) if not hasattr(provider_spec, "module"):
raise AttributeError(f"ProviderSpec of type {type(provider_spec)} does not have a 'module' attribute")
module = importlib.import_module(provider_spec.module)
args = [] args = []
if isinstance(provider_spec, RemoteProviderSpec): if isinstance(provider_spec, RemoteProviderSpec):
config_type = instantiate_class_type(provider_spec.config_class) config_type = instantiate_class_type(provider_spec.config_class)
@ -356,7 +395,7 @@ def check_protocol_compliance(obj: Any, protocol: Any) -> None:
obj_params = set(obj_sig.parameters) obj_params = set(obj_sig.parameters)
obj_params.discard("self") obj_params.discard("self")
if not (proto_params <= obj_params): if not (proto_params <= obj_params):
log.error(f"Method {name} incompatible proto: {proto_params} vs. obj: {obj_params}") logcat.error("core", f"Method {name} incompatible proto: {proto_params} vs. obj: {obj_params}")
missing_methods.append((name, "signature_mismatch")) missing_methods.append((name, "signature_mismatch"))
else: else:
# Check if the method is actually implemented in the class # Check if the method is actually implemented in the class
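The refactored resolver above splits the work into validate, sort, and instantiate steps. The dependency-ordering step amounts to a depth-first topological sort over provider dependency names; a minimal standalone sketch (the dependency edges here are assumed for illustration, not taken from the actual resolver):

```python
# Depth-first topological ordering: a node is appended only after all of its
# dependencies, yielding a safe instantiation order.
from typing import Dict, List, Set

deps: Dict[str, List[str]] = {
    "telemetry": [],
    "inner-inference": ["telemetry"],
    "models": [],
    "inference": ["models", "inner-inference", "telemetry"],
}

def topological_order(graph: Dict[str, List[str]]) -> List[str]:
    visited: Set[str] = set()
    order: List[str] = []

    def dfs(node: str) -> None:
        visited.add(node)
        for dep in graph.get(node, []):
            if dep not in visited:
                dfs(dep)
        order.append(node)  # appended only once its dependencies are in `order`

    for node in graph:
        if node not in visited:
            dfs(node)
    return order

print(topological_order(deps))
# ['telemetry', 'inner-inference', 'models', 'inference']
```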


@ -47,7 +47,7 @@ async def get_routing_table_impl(
return impl return impl
async def get_auto_router_impl(api: Api, routing_table: RoutingTable, _deps) -> Any: async def get_auto_router_impl(api: Api, routing_table: RoutingTable, deps: Dict[str, Any]) -> Any:
from .routers import ( from .routers import (
DatasetIORouter, DatasetIORouter,
EvalRouter, EvalRouter,
@ -69,9 +69,17 @@ async def get_auto_router_impl(api: Api, routing_table: RoutingTable, _deps) ->
"tool_runtime": ToolRuntimeRouter, "tool_runtime": ToolRuntimeRouter,
"preprocessing": PreprocessingRouter, "preprocessing": PreprocessingRouter,
} }
api_to_deps = {
"inference": {"telemetry": Api.telemetry},
}
if api.value not in api_to_routers: if api.value not in api_to_routers:
raise ValueError(f"API {api.value} not found in router map") raise ValueError(f"API {api.value} not found in router map")
impl = api_to_routers[api.value](routing_table) api_to_dep_impl = {}
for dep_name, dep_api in api_to_deps.get(api.value, {}).items():
if dep_api in deps:
api_to_dep_impl[dep_name] = deps[dep_api]
impl = api_to_routers[api.value](routing_table, **api_to_dep_impl)
await impl.initialize() await impl.initialize()
return impl return impl
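The new `deps` parameter lets the router factory inject selected API implementations (currently telemetry for inference) into a router's constructor by keyword. A compact sketch of that wiring with illustrative class and key names:

```python
# Sketch only: routers that declare extra dependencies get them injected as
# keyword arguments, looked up by name from the deps the stack resolved.
from typing import Any, Dict

class InferenceRouterSketch:
    def __init__(self, routing_table: Any, telemetry: Any = None) -> None:
        self.routing_table = routing_table
        self.telemetry = telemetry

api_to_routers = {"inference": InferenceRouterSketch}
api_to_deps = {"inference": {"telemetry": "telemetry"}}  # kwarg name -> deps key

def build_router(api: str, routing_table: Any, deps: Dict[str, Any]) -> Any:
    kwargs = {}
    for kwarg_name, dep_key in api_to_deps.get(api, {}).items():
        if dep_key in deps:
            kwargs[kwarg_name] = deps[dep_key]
    return api_to_routers[api](routing_table, **kwargs)

router = build_router("inference", routing_table=object(), deps={"telemetry": object()})
print(type(router).__name__, router.telemetry is not None)  # InferenceRouterSketch True
```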


@ -4,8 +4,10 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
from typing import Any, AsyncGenerator, Dict, List, Optional import time
from typing import Any, AsyncGenerator, AsyncIterator, Dict, List, Optional, Union
from llama_stack import logcat
from llama_stack.apis.common.content_types import ( from llama_stack.apis.common.content_types import (
URL, URL,
InterleavedContent, InterleavedContent,
@ -20,6 +22,10 @@ from llama_stack.apis.eval import (
JobStatus, JobStatus,
) )
from llama_stack.apis.inference import ( from llama_stack.apis.inference import (
ChatCompletionResponse,
ChatCompletionResponseEventType,
ChatCompletionResponseStreamChunk,
CompletionMessage,
EmbeddingsResponse, EmbeddingsResponse,
EmbeddingTaskType, EmbeddingTaskType,
Inference, Inference,
@ -27,13 +33,14 @@ from llama_stack.apis.inference import (
Message, Message,
ResponseFormat, ResponseFormat,
SamplingParams, SamplingParams,
StopReason,
TextTruncation, TextTruncation,
ToolChoice, ToolChoice,
ToolConfig, ToolConfig,
ToolDefinition, ToolDefinition,
ToolPromptFormat, ToolPromptFormat,
) )
from llama_stack.apis.models import ModelType from llama_stack.apis.models import Model, ModelType
from llama_stack.apis.preprocessing import ( from llama_stack.apis.preprocessing import (
Preprocessing, Preprocessing,
PreprocessingDataElement, PreprocessingDataElement,
@ -49,6 +56,7 @@ from llama_stack.apis.scoring import (
ScoringFnParams, ScoringFnParams,
) )
from llama_stack.apis.shields import Shield from llama_stack.apis.shields import Shield
from llama_stack.apis.telemetry import MetricEvent, Telemetry
from llama_stack.apis.tools import ( from llama_stack.apis.tools import (
RAGDocument, RAGDocument,
RAGQueryConfig, RAGQueryConfig,
@ -59,8 +67,10 @@ from llama_stack.apis.tools import (
) )
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
from llama_stack.distribution.utils.chain import execute_preprocessor_chain from llama_stack.distribution.utils.chain import execute_preprocessor_chain
from llama_stack.models.llama.llama3.chat_format import ChatFormat
from llama_stack.models.llama.llama3.tokenizer import Tokenizer
from llama_stack.providers.datatypes import RoutingTable from llama_stack.providers.datatypes import RoutingTable
from llama_stack.providers.utils.inference.prompt_adapter import get_default_tool_prompt_format from llama_stack.providers.utils.telemetry.tracing import get_current_span
class VectorIORouter(VectorIO): class VectorIORouter(VectorIO):
@ -70,12 +80,15 @@ class VectorIORouter(VectorIO):
self, self,
routing_table: RoutingTable, routing_table: RoutingTable,
) -> None: ) -> None:
logcat.debug("core", "Initializing VectorIORouter")
self.routing_table = routing_table self.routing_table = routing_table
async def initialize(self) -> None: async def initialize(self) -> None:
logcat.debug("core", "VectorIORouter.initialize")
pass pass
async def shutdown(self) -> None: async def shutdown(self) -> None:
logcat.debug("core", "VectorIORouter.shutdown")
pass pass
async def register_vector_db( async def register_vector_db(
@ -86,6 +99,10 @@ class VectorIORouter(VectorIO):
provider_id: Optional[str] = None, provider_id: Optional[str] = None,
provider_vector_db_id: Optional[str] = None, provider_vector_db_id: Optional[str] = None,
) -> None: ) -> None:
logcat.debug(
"core",
f"VectorIORouter.register_vector_db: {vector_db_id}, {embedding_model}",
)
await self.routing_table.register_vector_db( await self.routing_table.register_vector_db(
vector_db_id, vector_db_id,
embedding_model, embedding_model,
@ -100,6 +117,10 @@ class VectorIORouter(VectorIO):
chunks: List[Chunk], chunks: List[Chunk],
ttl_seconds: Optional[int] = None, ttl_seconds: Optional[int] = None,
) -> None: ) -> None:
logcat.debug(
"core",
f"VectorIORouter.insert_chunks: {vector_db_id}, {len(chunks)} chunks, ttl_seconds={ttl_seconds}, chunk_ids={[chunk.metadata['document_id'] for chunk in chunks[:3]]}{' and more...' if len(chunks) > 3 else ''}",
)
return await self.routing_table.get_provider_impl(vector_db_id).insert_chunks(vector_db_id, chunks, ttl_seconds) return await self.routing_table.get_provider_impl(vector_db_id).insert_chunks(vector_db_id, chunks, ttl_seconds)
async def query_chunks( async def query_chunks(
@ -108,6 +129,7 @@ class VectorIORouter(VectorIO):
query: InterleavedContent, query: InterleavedContent,
params: Optional[Dict[str, Any]] = None, params: Optional[Dict[str, Any]] = None,
) -> QueryChunksResponse: ) -> QueryChunksResponse:
logcat.debug("core", f"VectorIORouter.query_chunks: {vector_db_id}")
return await self.routing_table.get_provider_impl(vector_db_id).query_chunks(vector_db_id, query, params) return await self.routing_table.get_provider_impl(vector_db_id).query_chunks(vector_db_id, query, params)
@ -117,13 +139,21 @@ class InferenceRouter(Inference):
def __init__( def __init__(
self, self,
routing_table: RoutingTable, routing_table: RoutingTable,
telemetry: Optional[Telemetry] = None,
) -> None: ) -> None:
logcat.debug("core", "Initializing InferenceRouter")
self.routing_table = routing_table self.routing_table = routing_table
self.telemetry = telemetry
if self.telemetry:
self.tokenizer = Tokenizer.get_instance()
self.formatter = ChatFormat(self.tokenizer)
async def initialize(self) -> None: async def initialize(self) -> None:
logcat.debug("core", "InferenceRouter.initialize")
pass pass
async def shutdown(self) -> None: async def shutdown(self) -> None:
logcat.debug("core", "InferenceRouter.shutdown")
pass pass
async def register_model( async def register_model(
@ -134,13 +164,68 @@ class InferenceRouter(Inference):
metadata: Optional[Dict[str, Any]] = None, metadata: Optional[Dict[str, Any]] = None,
model_type: Optional[ModelType] = None, model_type: Optional[ModelType] = None,
) -> None: ) -> None:
logcat.debug(
"core",
f"InferenceRouter.register_model: {model_id=} {provider_model_id=} {provider_id=} {metadata=} {model_type=}",
)
await self.routing_table.register_model(model_id, provider_model_id, provider_id, metadata, model_type) await self.routing_table.register_model(model_id, provider_model_id, provider_id, metadata, model_type)
def _construct_metrics(
self, prompt_tokens: int, completion_tokens: int, total_tokens: int, model: Model
) -> List[MetricEvent]:
span = get_current_span()
metrics = [
("prompt_tokens", prompt_tokens),
("completion_tokens", completion_tokens),
("total_tokens", total_tokens),
]
metric_events = []
for metric_name, value in metrics:
metric_events.append(
MetricEvent(
trace_id=span.trace_id,
span_id=span.span_id,
metric=metric_name,
value=value,
timestamp=time.time(),
unit="tokens",
attributes={
"model_id": model.model_id,
"provider_id": model.provider_id,
},
)
)
return metric_events
async def _compute_and_log_token_usage(
self,
prompt_tokens: int,
completion_tokens: int,
total_tokens: int,
model: Model,
) -> List[MetricEvent]:
metrics = self._construct_metrics(prompt_tokens, completion_tokens, total_tokens, model)
if self.telemetry:
for metric in metrics:
await self.telemetry.log_event(metric)
return metrics
async def _count_tokens(
self,
messages: List[Message] | InterleavedContent,
tool_prompt_format: Optional[ToolPromptFormat] = None,
) -> Optional[int]:
if isinstance(messages, list):
encoded = self.formatter.encode_dialog_prompt(messages, tool_prompt_format)
else:
encoded = self.formatter.encode_content(messages)
return len(encoded.tokens) if encoded and encoded.tokens else 0
async def chat_completion( async def chat_completion(
self, self,
model_id: str, model_id: str,
messages: List[Message], messages: List[Message],
sampling_params: Optional[SamplingParams] = SamplingParams(), sampling_params: Optional[SamplingParams] = None,
response_format: Optional[ResponseFormat] = None, response_format: Optional[ResponseFormat] = None,
tools: Optional[List[ToolDefinition]] = None, tools: Optional[List[ToolDefinition]] = None,
tool_choice: Optional[ToolChoice] = None, tool_choice: Optional[ToolChoice] = None,
@ -148,7 +233,13 @@ class InferenceRouter(Inference):
stream: Optional[bool] = False, stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None, logprobs: Optional[LogProbConfig] = None,
tool_config: Optional[ToolConfig] = None, tool_config: Optional[ToolConfig] = None,
) -> AsyncGenerator: ) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]]:
logcat.debug(
"core",
f"InferenceRouter.chat_completion: {model_id=}, {stream=}, {messages=}, {tools=}, {tool_config=}, {response_format=}",
)
if sampling_params is None:
sampling_params = SamplingParams()
model = await self.routing_table.get_model(model_id) model = await self.routing_table.get_model(model_id)
if model is None: if model is None:
raise ValueError(f"Model '{model_id}' not found") raise ValueError(f"Model '{model_id}' not found")
@ -167,8 +258,6 @@ class InferenceRouter(Inference):
params["tool_prompt_format"] = tool_prompt_format params["tool_prompt_format"] = tool_prompt_format
tool_config = ToolConfig(**params) tool_config = ToolConfig(**params)
tool_config.tool_prompt_format = tool_config.tool_prompt_format or get_default_tool_prompt_format(model_id)
tools = tools or [] tools = tools or []
if tool_config.tool_choice == ToolChoice.none: if tool_config.tool_choice == ToolChoice.none:
tools = [] tools = []
@ -195,20 +284,63 @@ class InferenceRouter(Inference):
tool_config=tool_config, tool_config=tool_config,
) )
provider = self.routing_table.get_provider_impl(model_id) provider = self.routing_table.get_provider_impl(model_id)
prompt_tokens = await self._count_tokens(messages, tool_config.tool_prompt_format)
if stream: if stream:
return (chunk async for chunk in await provider.chat_completion(**params))
async def stream_generator():
completion_text = ""
async for chunk in await provider.chat_completion(**params):
if chunk.event.event_type == ChatCompletionResponseEventType.progress:
if chunk.event.delta.type == "text":
completion_text += chunk.event.delta.text
if chunk.event.event_type == ChatCompletionResponseEventType.complete:
completion_tokens = await self._count_tokens(
[CompletionMessage(content=completion_text, stop_reason=StopReason.end_of_turn)],
tool_config.tool_prompt_format,
)
total_tokens = (prompt_tokens or 0) + (completion_tokens or 0)
metrics = await self._compute_and_log_token_usage(
prompt_tokens or 0,
completion_tokens or 0,
total_tokens,
model,
)
chunk.metrics = metrics if chunk.metrics is None else chunk.metrics + metrics
yield chunk
return stream_generator()
else: else:
return await provider.chat_completion(**params) response = await provider.chat_completion(**params)
completion_tokens = await self._count_tokens(
[response.completion_message],
tool_config.tool_prompt_format,
)
total_tokens = (prompt_tokens or 0) + (completion_tokens or 0)
metrics = await self._compute_and_log_token_usage(
prompt_tokens or 0,
completion_tokens or 0,
total_tokens,
model,
)
response.metrics = metrics if response.metrics is None else response.metrics + metrics
return response
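The streaming branch above wraps the provider stream, accumulates the generated text, and only computes completion and total token counts once the final chunk arrives, attaching the resulting metrics to that chunk. A standalone sketch of the same accumulate-then-summarize pattern (whitespace splitting stands in for the real tokenizer; every name here is illustrative):

```python
# Wrap the provider stream, collect the generated text, and report token usage
# once the stream is exhausted.
import asyncio
from typing import AsyncIterator, Dict, List

async def fake_provider_stream() -> AsyncIterator[str]:
    for piece in ["The", " capital", " of", " France", " is", " Paris."]:
        yield piece

def count_tokens(text: str) -> int:
    return len(text.split())

async def stream_with_usage(
    prompt: str, stream: AsyncIterator[str], usage: Dict[str, int]
) -> AsyncIterator[str]:
    completion_text = ""
    async for chunk in stream:
        completion_text += chunk
        yield chunk
    # Only now is the full completion known, so token counts are computed last.
    usage["prompt_tokens"] = count_tokens(prompt)
    usage["completion_tokens"] = count_tokens(completion_text)
    usage["total_tokens"] = usage["prompt_tokens"] + usage["completion_tokens"]

async def main() -> None:
    usage: Dict[str, int] = {}
    chunks: List[str] = []
    async for chunk in stream_with_usage("What is the capital of France?", fake_provider_stream(), usage):
        chunks.append(chunk)
    print("".join(chunks))
    print(usage)  # the router emits these counts as telemetry MetricEvents instead

asyncio.run(main())
```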
async def completion( async def completion(
self, self,
model_id: str, model_id: str,
content: InterleavedContent, content: InterleavedContent,
sampling_params: Optional[SamplingParams] = SamplingParams(), sampling_params: Optional[SamplingParams] = None,
response_format: Optional[ResponseFormat] = None, response_format: Optional[ResponseFormat] = None,
stream: Optional[bool] = False, stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None, logprobs: Optional[LogProbConfig] = None,
) -> AsyncGenerator: ) -> AsyncGenerator:
if sampling_params is None:
sampling_params = SamplingParams()
logcat.debug(
"core",
f"InferenceRouter.completion: {model_id=}, {stream=}, {content=}, {sampling_params=}, {response_format=}",
)
model = await self.routing_table.get_model(model_id) model = await self.routing_table.get_model(model_id)
if model is None: if model is None:
raise ValueError(f"Model '{model_id}' not found") raise ValueError(f"Model '{model_id}' not found")
@ -223,10 +355,41 @@ class InferenceRouter(Inference):
stream=stream, stream=stream,
logprobs=logprobs, logprobs=logprobs,
) )
prompt_tokens = await self._count_tokens(content)
if stream: if stream:
return (chunk async for chunk in await provider.completion(**params))
async def stream_generator():
completion_text = ""
async for chunk in await provider.completion(**params):
if hasattr(chunk, "delta"):
completion_text += chunk.delta
if hasattr(chunk, "stop_reason") and chunk.stop_reason and self.telemetry:
completion_tokens = await self._count_tokens(completion_text)
total_tokens = (prompt_tokens or 0) + (completion_tokens or 0)
metrics = await self._compute_and_log_token_usage(
prompt_tokens or 0,
completion_tokens or 0,
total_tokens,
model,
)
chunk.metrics = metrics if chunk.metrics is None else chunk.metrics + metrics
yield chunk
return stream_generator()
else: else:
return await provider.completion(**params) response = await provider.completion(**params)
completion_tokens = await self._count_tokens(response.content)
total_tokens = (prompt_tokens or 0) + (completion_tokens or 0)
metrics = await self._compute_and_log_token_usage(
prompt_tokens or 0,
completion_tokens or 0,
total_tokens,
model,
)
response.metrics = metrics if response.metrics is None else response.metrics + metrics
return response
async def embeddings( async def embeddings(
self, self,
@ -236,6 +399,7 @@ class InferenceRouter(Inference):
output_dimension: Optional[int] = None, output_dimension: Optional[int] = None,
task_type: Optional[EmbeddingTaskType] = None, task_type: Optional[EmbeddingTaskType] = None,
) -> EmbeddingsResponse: ) -> EmbeddingsResponse:
logcat.debug("core", f"InferenceRouter.embeddings: {model_id}")
model = await self.routing_table.get_model(model_id) model = await self.routing_table.get_model(model_id)
if model is None: if model is None:
raise ValueError(f"Model '{model_id}' not found") raise ValueError(f"Model '{model_id}' not found")
@ -255,12 +419,15 @@ class SafetyRouter(Safety):
self, self,
routing_table: RoutingTable, routing_table: RoutingTable,
) -> None: ) -> None:
logcat.debug("core", "Initializing SafetyRouter")
self.routing_table = routing_table self.routing_table = routing_table
async def initialize(self) -> None: async def initialize(self) -> None:
logcat.debug("core", "SafetyRouter.initialize")
pass pass
async def shutdown(self) -> None: async def shutdown(self) -> None:
logcat.debug("core", "SafetyRouter.shutdown")
pass pass
async def register_shield( async def register_shield(
@ -270,6 +437,7 @@ class SafetyRouter(Safety):
provider_id: Optional[str] = None, provider_id: Optional[str] = None,
params: Optional[Dict[str, Any]] = None, params: Optional[Dict[str, Any]] = None,
) -> Shield: ) -> Shield:
logcat.debug("core", f"SafetyRouter.register_shield: {shield_id}")
return await self.routing_table.register_shield(shield_id, provider_shield_id, provider_id, params) return await self.routing_table.register_shield(shield_id, provider_shield_id, provider_id, params)
async def run_shield( async def run_shield(
@ -278,6 +446,7 @@ class SafetyRouter(Safety):
messages: List[Message], messages: List[Message],
params: Dict[str, Any] = None, params: Dict[str, Any] = None,
) -> RunShieldResponse: ) -> RunShieldResponse:
logcat.debug("core", f"SafetyRouter.run_shield: {shield_id}")
return await self.routing_table.get_provider_impl(shield_id).run_shield( return await self.routing_table.get_provider_impl(shield_id).run_shield(
shield_id=shield_id, shield_id=shield_id,
messages=messages, messages=messages,
@ -290,12 +459,15 @@ class DatasetIORouter(DatasetIO):
self, self,
routing_table: RoutingTable, routing_table: RoutingTable,
) -> None: ) -> None:
logcat.debug("core", "Initializing DatasetIORouter")
self.routing_table = routing_table self.routing_table = routing_table
async def initialize(self) -> None: async def initialize(self) -> None:
logcat.debug("core", "DatasetIORouter.initialize")
pass pass
async def shutdown(self) -> None: async def shutdown(self) -> None:
logcat.debug("core", "DatasetIORouter.shutdown")
pass pass
async def get_rows_paginated( async def get_rows_paginated(
@ -305,6 +477,10 @@ class DatasetIORouter(DatasetIO):
page_token: Optional[str] = None, page_token: Optional[str] = None,
filter_condition: Optional[str] = None, filter_condition: Optional[str] = None,
) -> PaginatedRowsResult: ) -> PaginatedRowsResult:
logcat.debug(
"core",
f"DatasetIORouter.get_rows_paginated: {dataset_id}, rows_in_page={rows_in_page}",
)
return await self.routing_table.get_provider_impl(dataset_id).get_rows_paginated( return await self.routing_table.get_provider_impl(dataset_id).get_rows_paginated(
dataset_id=dataset_id, dataset_id=dataset_id,
rows_in_page=rows_in_page, rows_in_page=rows_in_page,
@ -313,6 +489,7 @@ class DatasetIORouter(DatasetIO):
) )
async def append_rows(self, dataset_id: str, rows: List[Dict[str, Any]]) -> None: async def append_rows(self, dataset_id: str, rows: List[Dict[str, Any]]) -> None:
logcat.debug("core", f"DatasetIORouter.append_rows: {dataset_id}, {len(rows)} rows")
return await self.routing_table.get_provider_impl(dataset_id).append_rows( return await self.routing_table.get_provider_impl(dataset_id).append_rows(
dataset_id=dataset_id, dataset_id=dataset_id,
rows=rows, rows=rows,
@ -324,12 +501,15 @@ class ScoringRouter(Scoring):
self, self,
routing_table: RoutingTable, routing_table: RoutingTable,
) -> None: ) -> None:
logcat.debug("core", "Initializing ScoringRouter")
self.routing_table = routing_table self.routing_table = routing_table
async def initialize(self) -> None: async def initialize(self) -> None:
logcat.debug("core", "ScoringRouter.initialize")
pass pass
async def shutdown(self) -> None: async def shutdown(self) -> None:
logcat.debug("core", "ScoringRouter.shutdown")
pass pass
async def score_batch( async def score_batch(
@ -338,6 +518,7 @@ class ScoringRouter(Scoring):
scoring_functions: Dict[str, Optional[ScoringFnParams]] = None, scoring_functions: Dict[str, Optional[ScoringFnParams]] = None,
save_results_dataset: bool = False, save_results_dataset: bool = False,
) -> ScoreBatchResponse: ) -> ScoreBatchResponse:
logcat.debug("core", f"ScoringRouter.score_batch: {dataset_id}")
res = {} res = {}
for fn_identifier in scoring_functions.keys(): for fn_identifier in scoring_functions.keys():
score_response = await self.routing_table.get_provider_impl(fn_identifier).score_batch( score_response = await self.routing_table.get_provider_impl(fn_identifier).score_batch(
@ -358,6 +539,10 @@ class ScoringRouter(Scoring):
input_rows: List[Dict[str, Any]], input_rows: List[Dict[str, Any]],
scoring_functions: Dict[str, Optional[ScoringFnParams]] = None, scoring_functions: Dict[str, Optional[ScoringFnParams]] = None,
) -> ScoreResponse: ) -> ScoreResponse:
logcat.debug(
"core",
f"ScoringRouter.score: {len(input_rows)} rows, {len(scoring_functions)} functions",
)
res = {} res = {}
# look up and map each scoring function to its provider impl # look up and map each scoring function to its provider impl
for fn_identifier in scoring_functions.keys(): for fn_identifier in scoring_functions.keys():
@ -375,22 +560,26 @@ class EvalRouter(Eval):
self, self,
routing_table: RoutingTable, routing_table: RoutingTable,
) -> None: ) -> None:
logcat.debug("core", "Initializing EvalRouter")
self.routing_table = routing_table self.routing_table = routing_table
async def initialize(self) -> None: async def initialize(self) -> None:
logcat.debug("core", "EvalRouter.initialize")
pass pass
async def shutdown(self) -> None: async def shutdown(self) -> None:
logcat.debug("core", "EvalRouter.shutdown")
pass pass
async def run_eval( async def run_eval(
self, self,
benchmark_id: str, benchmark_id: str,
task_config: BenchmarkConfig, benchmark_config: BenchmarkConfig,
) -> Job: ) -> Job:
logcat.debug("core", f"EvalRouter.run_eval: {benchmark_id}")
return await self.routing_table.get_provider_impl(benchmark_id).run_eval( return await self.routing_table.get_provider_impl(benchmark_id).run_eval(
benchmark_id=benchmark_id, benchmark_id=benchmark_id,
task_config=task_config, benchmark_config=benchmark_config,
) )
async def evaluate_rows( async def evaluate_rows(
@ -398,13 +587,14 @@ class EvalRouter(Eval):
benchmark_id: str, benchmark_id: str,
input_rows: List[Dict[str, Any]], input_rows: List[Dict[str, Any]],
scoring_functions: List[str], scoring_functions: List[str],
task_config: BenchmarkConfig, benchmark_config: BenchmarkConfig,
) -> EvaluateResponse: ) -> EvaluateResponse:
logcat.debug("core", f"EvalRouter.evaluate_rows: {benchmark_id}, {len(input_rows)} rows")
return await self.routing_table.get_provider_impl(benchmark_id).evaluate_rows( return await self.routing_table.get_provider_impl(benchmark_id).evaluate_rows(
benchmark_id=benchmark_id, benchmark_id=benchmark_id,
input_rows=input_rows, input_rows=input_rows,
scoring_functions=scoring_functions, scoring_functions=scoring_functions,
task_config=task_config, benchmark_config=benchmark_config,
) )
async def job_status( async def job_status(
@ -412,6 +602,7 @@ class EvalRouter(Eval):
benchmark_id: str, benchmark_id: str,
job_id: str, job_id: str,
) -> Optional[JobStatus]: ) -> Optional[JobStatus]:
logcat.debug("core", f"EvalRouter.job_status: {benchmark_id}, {job_id}")
return await self.routing_table.get_provider_impl(benchmark_id).job_status(benchmark_id, job_id) return await self.routing_table.get_provider_impl(benchmark_id).job_status(benchmark_id, job_id)
async def job_cancel( async def job_cancel(
@ -419,6 +610,7 @@ class EvalRouter(Eval):
benchmark_id: str, benchmark_id: str,
job_id: str, job_id: str,
) -> None: ) -> None:
logcat.debug("core", f"EvalRouter.job_cancel: {benchmark_id}, {job_id}")
await self.routing_table.get_provider_impl(benchmark_id).job_cancel( await self.routing_table.get_provider_impl(benchmark_id).job_cancel(
benchmark_id, benchmark_id,
job_id, job_id,
@ -429,6 +621,7 @@ class EvalRouter(Eval):
benchmark_id: str, benchmark_id: str,
job_id: str, job_id: str,
) -> EvaluateResponse: ) -> EvaluateResponse:
logcat.debug("core", f"EvalRouter.job_result: {benchmark_id}, {job_id}")
return await self.routing_table.get_provider_impl(benchmark_id).job_result( return await self.routing_table.get_provider_impl(benchmark_id).job_result(
benchmark_id, benchmark_id,
job_id, job_id,
@ -441,6 +634,7 @@ class ToolRuntimeRouter(ToolRuntime):
self, self,
routing_table: RoutingTable, routing_table: RoutingTable,
) -> None: ) -> None:
logcat.debug("core", "Initializing ToolRuntimeRouter.RagToolImpl")
self.routing_table = routing_table self.routing_table = routing_table
async def query( async def query(
@ -449,6 +643,7 @@ class ToolRuntimeRouter(ToolRuntime):
vector_db_ids: List[str], vector_db_ids: List[str],
query_config: Optional[RAGQueryConfig] = None, query_config: Optional[RAGQueryConfig] = None,
) -> RAGQueryResult: ) -> RAGQueryResult:
logcat.debug("core", f"ToolRuntimeRouter.RagToolImpl.query: {vector_db_ids}")
return await self.routing_table.get_provider_impl("knowledge_search").query( return await self.routing_table.get_provider_impl("knowledge_search").query(
content, vector_db_ids, query_config content, vector_db_ids, query_config
) )
@ -460,6 +655,10 @@ class ToolRuntimeRouter(ToolRuntime):
chunk_size_in_tokens: int = 512, chunk_size_in_tokens: int = 512,
preprocessor_chain: Optional[PreprocessorChain] = None, preprocessor_chain: Optional[PreprocessorChain] = None,
) -> None: ) -> None:
logcat.debug(
"core",
f"ToolRuntimeRouter.RagToolImpl.insert: {vector_db_id}, {len(documents)} documents, chunk_size={chunk_size_in_tokens}",
)
return await self.routing_table.get_provider_impl("insert_into_memory").insert( return await self.routing_table.get_provider_impl("insert_into_memory").insert(
documents, vector_db_id, chunk_size_in_tokens, preprocessor_chain documents, vector_db_id, chunk_size_in_tokens, preprocessor_chain
) )
@ -468,6 +667,7 @@ class ToolRuntimeRouter(ToolRuntime):
self, self,
routing_table: RoutingTable, routing_table: RoutingTable,
) -> None: ) -> None:
logcat.debug("core", "Initializing ToolRuntimeRouter")
self.routing_table = routing_table self.routing_table = routing_table
# HACK ALERT this should be in sync with "get_all_api_endpoints()" # HACK ALERT this should be in sync with "get_all_api_endpoints()"
@ -476,12 +676,15 @@ class ToolRuntimeRouter(ToolRuntime):
setattr(self, f"rag_tool.{method}", getattr(self.rag_tool, method)) setattr(self, f"rag_tool.{method}", getattr(self.rag_tool, method))
async def initialize(self) -> None: async def initialize(self) -> None:
logcat.debug("core", "ToolRuntimeRouter.initialize")
pass pass
async def shutdown(self) -> None: async def shutdown(self) -> None:
logcat.debug("core", "ToolRuntimeRouter.shutdown")
pass pass
async def invoke_tool(self, tool_name: str, kwargs: Dict[str, Any]) -> Any: async def invoke_tool(self, tool_name: str, kwargs: Dict[str, Any]) -> Any:
logcat.debug("core", f"ToolRuntimeRouter.invoke_tool: {tool_name}")
return await self.routing_table.get_provider_impl(tool_name).invoke_tool( return await self.routing_table.get_provider_impl(tool_name).invoke_tool(
tool_name=tool_name, tool_name=tool_name,
kwargs=kwargs, kwargs=kwargs,
@ -490,6 +693,7 @@ class ToolRuntimeRouter(ToolRuntime):
async def list_runtime_tools( async def list_runtime_tools(
self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
) -> List[ToolDef]: ) -> List[ToolDef]:
logcat.debug("core", f"ToolRuntimeRouter.list_runtime_tools: {tool_group_id}")
return await self.routing_table.get_provider_impl(tool_group_id).list_tools(tool_group_id, mcp_endpoint) return await self.routing_table.get_provider_impl(tool_group_id).list_tools(tool_group_id, mcp_endpoint)
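Beyond the logging additions, the eval entry points above rename the `task_config` keyword to `benchmark_config`. A minimal call-site sketch of the new keyword; the router and config objects are assumed to be constructed elsewhere and the benchmark id is hypothetical:

```python
# Call-site sketch for the renamed keyword (task_config -> benchmark_config).
# `eval_router` is an EvalRouter and `benchmark_config` a BenchmarkConfig,
# both assumed to be built elsewhere (e.g. by the stack resolver).
async def run_benchmark(eval_router, benchmark_config):
    return await eval_router.run_eval(
        benchmark_id="meta-reference-mmlu",   # hypothetical benchmark id
        benchmark_config=benchmark_config,    # was: task_config=...
    )
```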
@ -318,13 +318,14 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
if provider_vector_db_id is None: if provider_vector_db_id is None:
provider_vector_db_id = vector_db_id provider_vector_db_id = vector_db_id
if provider_id is None: if provider_id is None:
# If provider_id not specified, use the only provider if it supports this shield type if len(self.impls_by_provider_id) > 0:
if len(self.impls_by_provider_id) == 1:
provider_id = list(self.impls_by_provider_id.keys())[0] provider_id = list(self.impls_by_provider_id.keys())[0]
if len(self.impls_by_provider_id) > 1:
logger.warning(
f"No provider specified and multiple providers available. Arbitrarily selected the first provider {provider_id}."
)
else: else:
raise ValueError( raise ValueError("No provider available. Please configure a vector_io provider.")
"No provider specified and multiple providers available. Please specify a provider_id."
)
model = await self.get_object_by_identifier("model", embedding_model) model = await self.get_object_by_identifier("model", embedding_model)
if model is None: if model is None:
raise ValueError(f"Model {embedding_model} not found") raise ValueError(f"Model {embedding_model} not found")
@ -375,7 +376,7 @@ class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets):
provider_id = list(self.impls_by_provider_id.keys())[0] provider_id = list(self.impls_by_provider_id.keys())[0]
else: else:
raise ValueError( raise ValueError(
"No provider specified and multiple providers available. Please specify a provider_id." f"No provider specified and multiple providers available. Please specify a provider_id. Available providers: {self.impls_by_provider_id.keys()}"
) )
if metadata is None: if metadata is None:
metadata = {} metadata = {}
@ -26,9 +26,9 @@ from fastapi import Path as FastapiPath
from fastapi.exceptions import RequestValidationError from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse, StreamingResponse from fastapi.responses import JSONResponse, StreamingResponse
from pydantic import BaseModel, ValidationError from pydantic import BaseModel, ValidationError
from termcolor import cprint
from typing_extensions import Annotated from typing_extensions import Annotated
from llama_stack import logcat
from llama_stack.distribution.datatypes import StackRunConfig from llama_stack.distribution.datatypes import StackRunConfig
from llama_stack.distribution.distribution import builtin_automatically_routed_apis from llama_stack.distribution.distribution import builtin_automatically_routed_apis
from llama_stack.distribution.request_headers import set_request_provider_data from llama_stack.distribution.request_headers import set_request_provider_data
@ -55,7 +55,7 @@ from .endpoints import get_all_api_endpoints
REPO_ROOT = Path(__file__).parent.parent.parent.parent REPO_ROOT = Path(__file__).parent.parent.parent.parent
logging.basicConfig(level=logging.INFO, format="%(levelname)s %(asctime)s %(name)s:%(lineno)d: %(message)s") logging.basicConfig(level=logging.INFO, format="%(levelname)s %(asctime)s %(name)s:%(lineno)d: %(message)s")
logger = logging.getLogger(__name__) logcat.init()
def warn_with_traceback(message, category, filename, lineno, file=None, line=None): def warn_with_traceback(message, category, filename, lineno, file=None, line=None):
@ -142,23 +142,23 @@ def handle_signal(app, signum, _) -> None:
not block the current execution. not block the current execution.
""" """
signame = signal.Signals(signum).name signame = signal.Signals(signum).name
logger.info(f"Received signal {signame} ({signum}). Exiting gracefully...") logcat.info("server", f"Received signal {signame} ({signum}). Exiting gracefully...")
async def shutdown(): async def shutdown():
try: try:
# Gracefully shut down implementations # Gracefully shut down implementations
for impl in app.__llama_stack_impls__.values(): for impl in app.__llama_stack_impls__.values():
impl_name = impl.__class__.__name__ impl_name = impl.__class__.__name__
logger.info("Shutting down %s", impl_name) logcat.info("server", f"Shutting down {impl_name}")
try: try:
if hasattr(impl, "shutdown"): if hasattr(impl, "shutdown"):
await asyncio.wait_for(impl.shutdown(), timeout=5) await asyncio.wait_for(impl.shutdown(), timeout=5)
else: else:
logger.warning("No shutdown method for %s", impl_name) logcat.warning("server", f"No shutdown method for {impl_name}")
except asyncio.TimeoutError: except asyncio.TimeoutError:
logger.exception("Shutdown timeout for %s ", impl_name, exc_info=True) logcat.exception("server", f"Shutdown timeout for {impl_name}")
except Exception as e: except Exception as e:
logger.exception("Failed to shutdown %s: %s", impl_name, {e}) logcat.exception("server", f"Failed to shutdown {impl_name}: {e}")
# Gather all running tasks # Gather all running tasks
loop = asyncio.get_running_loop() loop = asyncio.get_running_loop()
@ -172,7 +172,7 @@ def handle_signal(app, signum, _) -> None:
try: try:
await asyncio.wait_for(asyncio.gather(*tasks, return_exceptions=True), timeout=10) await asyncio.wait_for(asyncio.gather(*tasks, return_exceptions=True), timeout=10)
except asyncio.TimeoutError: except asyncio.TimeoutError:
logger.exception("Timeout while waiting for tasks to finish") logcat.exception("server", "Timeout while waiting for tasks to finish")
except asyncio.CancelledError: except asyncio.CancelledError:
pass pass
finally: finally:
@ -184,9 +184,9 @@ def handle_signal(app, signum, _) -> None:
@asynccontextmanager @asynccontextmanager
async def lifespan(app: FastAPI): async def lifespan(app: FastAPI):
logger.info("Starting up") logcat.info("server", "Starting up")
yield yield
logger.info("Shutting down") logcat.info("server", "Shutting down")
for impl in app.__llama_stack_impls__.values(): for impl in app.__llama_stack_impls__.values():
await impl.shutdown() await impl.shutdown()
@ -209,10 +209,11 @@ async def sse_generator(event_gen):
yield create_sse_event(item) yield create_sse_event(item)
await asyncio.sleep(0.01) await asyncio.sleep(0.01)
except asyncio.CancelledError: except asyncio.CancelledError:
print("Generator cancelled") logcat.info("server", "Generator cancelled")
await event_gen.aclose() await event_gen.aclose()
except Exception as e: except Exception as e:
traceback.print_exception(e) logcat.exception("server", f"Error in sse_generator: {e}")
logcat.exception("server", f"Traceback: {''.join(traceback.format_exception(type(e), e, e.__traceback__))}")
yield create_sse_event( yield create_sse_event(
{ {
"error": { "error": {
@ -234,7 +235,7 @@ def create_dynamic_typed_route(func: Any, method: str, route: str):
value = func(**kwargs) value = func(**kwargs)
return await maybe_await(value) return await maybe_await(value)
except Exception as e: except Exception as e:
traceback.print_exception(e) logcat.exception("server", f"Error in {func.__name__}")
raise translate_exception(e) from e raise translate_exception(e) from e
sig = inspect.signature(func) sig = inspect.signature(func)
@ -313,6 +314,8 @@ class ClientVersionMiddleware:
def main(): def main():
logcat.init()
"""Start the LlamaStack server.""" """Start the LlamaStack server."""
parser = argparse.ArgumentParser(description="Start the LlamaStack server.") parser = argparse.ArgumentParser(description="Start the LlamaStack server.")
parser.add_argument( parser.add_argument(
@ -352,10 +355,10 @@ def main():
for env_pair in args.env: for env_pair in args.env:
try: try:
key, value = validate_env_pair(env_pair) key, value = validate_env_pair(env_pair)
logger.info(f"Setting CLI environment variable {key} => {value}") logcat.info("server", f"Setting CLI environment variable {key} => {value}")
os.environ[key] = value os.environ[key] = value
except ValueError as e: except ValueError as e:
logger.error(f"Error: {str(e)}") logcat.error("server", f"Error: {str(e)}")
sys.exit(1) sys.exit(1)
if args.yaml_config: if args.yaml_config:
@ -363,12 +366,12 @@ def main():
config_file = Path(args.yaml_config) config_file = Path(args.yaml_config)
if not config_file.exists(): if not config_file.exists():
raise ValueError(f"Config file {config_file} does not exist") raise ValueError(f"Config file {config_file} does not exist")
logger.info(f"Using config file: {config_file}") logcat.info("server", f"Using config file: {config_file}")
elif args.template: elif args.template:
config_file = Path(REPO_ROOT) / "llama_stack" / "templates" / args.template / "run.yaml" config_file = Path(REPO_ROOT) / "llama_stack" / "templates" / args.template / "run.yaml"
if not config_file.exists(): if not config_file.exists():
raise ValueError(f"Template {args.template} does not exist") raise ValueError(f"Template {args.template} does not exist")
logger.info(f"Using template {args.template} config file: {config_file}") logcat.info("server", f"Using template {args.template} config file: {config_file}")
else: else:
raise ValueError("Either --yaml-config or --template must be provided") raise ValueError("Either --yaml-config or --template must be provided")
@ -376,9 +379,10 @@ def main():
config = replace_env_vars(yaml.safe_load(fp)) config = replace_env_vars(yaml.safe_load(fp))
config = StackRunConfig(**config) config = StackRunConfig(**config)
logger.info("Run configuration:") logcat.info("server", "Run configuration:")
safe_config = redact_sensitive_fields(config.model_dump()) safe_config = redact_sensitive_fields(config.model_dump())
logger.info(yaml.dump(safe_config, indent=2)) for log_line in yaml.dump(safe_config, indent=2).split("\n"):
logcat.info("server", log_line)
app = FastAPI(lifespan=lifespan) app = FastAPI(lifespan=lifespan)
app.add_middleware(TracingMiddleware) app.add_middleware(TracingMiddleware)
@ -388,7 +392,7 @@ def main():
try: try:
impls = asyncio.run(construct_stack(config)) impls = asyncio.run(construct_stack(config))
except InvalidProviderError as e: except InvalidProviderError as e:
logger.error(f"Error: {str(e)}") logcat.error("server", f"Error: {str(e)}")
sys.exit(1) sys.exit(1)
if Api.telemetry in impls: if Api.telemetry in impls:
@ -433,11 +437,8 @@ def main():
) )
) )
logger.info(f"Serving API {api_str}") logcat.debug("server", f"serving APIs: {apis_to_serve}")
for endpoint in endpoints:
cprint(f" {endpoint.method.upper()} {endpoint.route}", "white")
print("")
app.exception_handler(RequestValidationError)(global_exception_handler) app.exception_handler(RequestValidationError)(global_exception_handler)
app.exception_handler(Exception)(global_exception_handler) app.exception_handler(Exception)(global_exception_handler)
signal.signal(signal.SIGINT, functools.partial(handle_signal, app)) signal.signal(signal.SIGINT, functools.partial(handle_signal, app))
@ -463,10 +464,10 @@ def main():
"ssl_keyfile": keyfile, "ssl_keyfile": keyfile,
"ssl_certfile": certfile, "ssl_certfile": certfile,
} }
logger.info(f"HTTPS enabled with certificates:\n Key: {keyfile}\n Cert: {certfile}") logcat.info("server", f"HTTPS enabled with certificates:\n Key: {keyfile}\n Cert: {certfile}")
listen_host = ["::", "0.0.0.0"] if not args.disable_ipv6 else "0.0.0.0" listen_host = ["::", "0.0.0.0"] if not args.disable_ipv6 else "0.0.0.0"
logger.info(f"Listening on {listen_host}:{port}") logcat.info("server", f"Listening on {listen_host}:{port}")
uvicorn_config = { uvicorn_config = {
"app": app, "app": app,
@ -5,14 +5,15 @@
# the root directory of this source tree. # the root directory of this source tree.
import importlib.resources import importlib.resources
import logging
import os import os
import re import re
import tempfile
from typing import Any, Dict, Optional from typing import Any, Dict, Optional
import yaml import yaml
from termcolor import colored from termcolor import colored
from llama_stack import logcat
from llama_stack.apis.agents import Agents from llama_stack.apis.agents import Agents
from llama_stack.apis.batch_inference import BatchInference from llama_stack.apis.batch_inference import BatchInference
from llama_stack.apis.benchmarks import Benchmarks from llama_stack.apis.benchmarks import Benchmarks
@ -35,14 +36,13 @@ from llama_stack.apis.telemetry import Telemetry
from llama_stack.apis.tools import RAGToolRuntime, ToolGroups, ToolRuntime from llama_stack.apis.tools import RAGToolRuntime, ToolGroups, ToolRuntime
from llama_stack.apis.vector_dbs import VectorDBs from llama_stack.apis.vector_dbs import VectorDBs
from llama_stack.apis.vector_io import VectorIO from llama_stack.apis.vector_io import VectorIO
from llama_stack.distribution.datatypes import StackRunConfig from llama_stack.distribution.datatypes import Provider, StackRunConfig
from llama_stack.distribution.distribution import get_provider_registry from llama_stack.distribution.distribution import get_provider_registry
from llama_stack.distribution.resolver import ProviderRegistry, resolve_impls from llama_stack.distribution.resolver import ProviderRegistry, resolve_impls
from llama_stack.distribution.store.registry import create_dist_registry from llama_stack.distribution.store.registry import create_dist_registry
from llama_stack.distribution.utils.dynamic import instantiate_class_type
from llama_stack.providers.datatypes import Api from llama_stack.providers.datatypes import Api
log = logging.getLogger(__name__)
class LlamaStack( class LlamaStack(
VectorDBs, VectorDBs,
@ -106,12 +106,11 @@ async def register_resources(run_config: StackRunConfig, impls: Dict[Api, Any]):
objects_to_process = response.data if hasattr(response, "data") else response objects_to_process = response.data if hasattr(response, "data") else response
for obj in objects_to_process: for obj in objects_to_process:
log.info( logcat.debug(
"core",
f"{rsrc.capitalize()}: {colored(obj.identifier, 'white', attrs=['bold'])} served by {colored(obj.provider_id, 'white', attrs=['bold'])}", f"{rsrc.capitalize()}: {colored(obj.identifier, 'white', attrs=['bold'])} served by {colored(obj.provider_id, 'white', attrs=['bold'])}",
) )
log.info("")
class EnvVarError(Exception): class EnvVarError(Exception):
def __init__(self, var_name: str, path: str = ""): def __init__(self, var_name: str, path: str = ""):
@ -160,18 +159,34 @@ def replace_env_vars(config: Any, path: str = "") -> Any:
return result return result
elif isinstance(config, str): elif isinstance(config, str):
pattern = r"\${env\.([A-Z0-9_]+)(?::([^}]*))?}" # Updated pattern to support both default values (:) and conditional values (+)
pattern = r"\${env\.([A-Z0-9_]+)(?:([:\+])([^}]*))?}"
def get_env_var(match): def get_env_var(match):
env_var = match.group(1) env_var = match.group(1)
default_val = match.group(2) operator = match.group(2) # ':' for default, '+' for conditional
value_expr = match.group(3)
value = os.environ.get(env_var) env_value = os.environ.get(env_var)
if not value:
if default_val is None: if operator == ":": # Default value syntax: ${env.FOO:default}
raise EnvVarError(env_var, path) if not env_value:
if value_expr is None:
raise EnvVarError(env_var, path)
else:
value = value_expr
else: else:
value = default_val value = env_value
elif operator == "+": # Conditional value syntax: ${env.FOO+value_if_set}
if env_value:
value = value_expr
else:
# If env var is not set, return empty string for the conditional case
value = ""
else: # No operator case: ${env.FOO}
if not env_value:
raise EnvVarError(env_var, path)
value = env_value
# expand "~" from the values # expand "~" from the values
return os.path.expanduser(value) return os.path.expanduser(value)
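To illustrate the new substitution rules (`:` for a default value, `+` for a conditional value), here is a small self-contained sketch that reproduces just the regex and branch logic from the hunk above. It is a simplified stand-in for illustration, not the stack's own implementation; the variable names and values are placeholders.

```python
import os
import re

# Mirrors the pattern from the hunk: ${env.VAR}, ${env.VAR:default}, ${env.VAR+value_if_set}
_PATTERN = r"\${env\.([A-Z0-9_]+)(?:([:\+])([^}]*))?}"

def expand(text: str) -> str:
    def get_env_var(match):
        env_var, operator, value_expr = match.group(1), match.group(2), match.group(3)
        env_value = os.environ.get(env_var)
        if operator == ":":          # default: used when the variable is unset or empty
            value = env_value if env_value else (value_expr or "")
        elif operator == "+":        # conditional: emitted only when the variable is set
            value = value_expr if env_value else ""
        else:                        # bare ${env.VAR}: the variable is required
            if not env_value:
                raise ValueError(f"Missing environment variable {env_var}")
            value = env_value
        return os.path.expanduser(value)

    return re.sub(_PATTERN, get_env_var, text)

os.environ["LLAMA_STACK_PORT"] = "8321"
print(expand("port: ${env.LLAMA_STACK_PORT:5001}"))  # port: 8321
print(expand("tls: ${env.TLS_CERT+enabled}"))        # tls:   (assuming TLS_CERT is unset)
```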
@ -220,3 +235,53 @@ def get_stack_run_config_from_template(template: str) -> StackRunConfig:
run_config = yaml.safe_load(path.open()) run_config = yaml.safe_load(path.open())
return StackRunConfig(**replace_env_vars(run_config)) return StackRunConfig(**replace_env_vars(run_config))
def run_config_from_adhoc_config_spec(
adhoc_config_spec: str, provider_registry: Optional[ProviderRegistry] = None
) -> StackRunConfig:
"""
Create an adhoc distribution from a list of API providers.
The list should be of the form "api=provider", e.g. "inference=fireworks". If you have
multiple pairs, separate them with commas or semicolons, e.g. "inference=fireworks,safety=llama-guard,agents=meta-reference"
"""
api_providers = adhoc_config_spec.replace(";", ",").split(",")
provider_registry = provider_registry or get_provider_registry()
distro_dir = tempfile.mkdtemp()
provider_configs_by_api = {}
for api_provider in api_providers:
api_str, provider = api_provider.split("=")
api = Api(api_str)
providers_by_type = provider_registry[api]
provider_spec = providers_by_type.get(provider)
if not provider_spec:
provider_spec = providers_by_type.get(f"inline::{provider}")
if not provider_spec:
provider_spec = providers_by_type.get(f"remote::{provider}")
if not provider_spec:
raise ValueError(
f"Provider {provider} (or remote::{provider} or inline::{provider}) not found for API {api}"
)
# call method "sample_run_config" on the provider spec config class
provider_config_type = instantiate_class_type(provider_spec.config_class)
provider_config = replace_env_vars(provider_config_type.sample_run_config(__distro_dir__=distro_dir))
provider_configs_by_api[api_str] = [
Provider(
provider_id=provider,
provider_type=provider_spec.provider_type,
config=provider_config,
)
]
config = StackRunConfig(
image_name="distro-test",
apis=list(provider_configs_by_api.keys()),
providers=provider_configs_by_api,
)
return config
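A usage sketch for the new ad-hoc helper. The module path (`llama_stack.distribution.stack`) and the provider names are assumptions; they must match what is available in the local provider registry.

```python
# Usage sketch for run_config_from_adhoc_config_spec; module path and provider
# names are assumptions and must exist in the local provider registry.
from llama_stack.distribution.stack import run_config_from_adhoc_config_spec

run_config = run_config_from_adhoc_config_spec(
    "inference=ollama,safety=llama-guard,agents=meta-reference"
)
print(run_config.image_name)         # distro-test
print(sorted(run_config.providers))  # ['agents', 'inference', 'safety']
```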
@ -98,15 +98,20 @@ case "$env_type" in
*) *)
esac esac
set -x
if [[ "$env_type" == "venv" || "$env_type" == "conda" ]]; then if [[ "$env_type" == "venv" || "$env_type" == "conda" ]]; then
set -x
$PYTHON_BINARY -m llama_stack.distribution.server.server \ $PYTHON_BINARY -m llama_stack.distribution.server.server \
--yaml-config "$yaml_config" \ --yaml-config "$yaml_config" \
--port "$port" \ --port "$port" \
$env_vars \ $env_vars \
$other_args $other_args
elif [[ "$env_type" == "container" ]]; then elif [[ "$env_type" == "container" ]]; then
# Check if container command is available
if ! is_command_available $CONTAINER_BINARY; then
printf "${RED}Error: ${CONTAINER_BINARY} command not found. Is ${CONTAINER_BINARY} installed and in your PATH?${NC}" >&2
exit 1
fi
if is_command_available selinuxenabled &> /dev/null && selinuxenabled; then if is_command_available selinuxenabled &> /dev/null && selinuxenabled; then
# Disable SELinux labels # Disable SELinux labels
CONTAINER_OPTS="$CONTAINER_OPTS --security-opt label=disable" CONTAINER_OPTS="$CONTAINER_OPTS --security-opt label=disable"
@ -136,6 +141,8 @@ elif [[ "$env_type" == "container" ]]; then
version_tag=$(curl -s $URL | jq -r '.info.version') version_tag=$(curl -s $URL | jq -r '.info.version')
fi fi
set -x
$CONTAINER_BINARY run $CONTAINER_OPTS -it \ $CONTAINER_BINARY run $CONTAINER_OPTS -it \
-p $port:$port \ -p $port:$port \
$env_vars \ $env_vars \
@ -1,72 +0,0 @@
#!/bin/bash
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
set -euo pipefail
RED='\033[0;31m'
NC='\033[0m' # No Color
error_handler() {
echo "Error occurred in script at line: ${1}" >&2
exit 1
}
trap 'error_handler ${LINENO}' ERR
if [ $# -lt 3 ]; then
echo "Usage: $0 <venv_path> <yaml_config> <port> <script_args...>"
exit 1
fi
venv_path="$1"
shift
yaml_config="$1"
shift
port="$1"
shift
# Initialize env_vars as an empty array
env_vars=""
other_args=""
# Process environment variables from --env arguments
while [[ $# -gt 0 ]]; do
case "$1" in
--env)
if [[ -n "$2" ]]; then
env_vars="$env_vars --env $2"
shift 2
else
echo -e "${RED}Error: --env requires a KEY=VALUE argument${NC}" >&2
exit 1
fi
;;
*)
other_args="$other_args $1"
shift
;;
esac
done
echo "Using virtual environment: $venv_path"
# Activate virtual environment
if [ ! -d "$venv_path" ]; then
echo -e "${RED}Error: Virtual environment not found at $venv_path${NC}" >&2
exit 1
fi
source "$venv_path/bin/activate"
set -x
python -m llama_stack.distribution.server.server \
--yaml-config "$yaml_config" \
--port "$port" \
$env_vars \
$other_args
@ -17,7 +17,7 @@ llama stack run together
2. (Optional) Register datasets and eval tasks as resources. If you want to run pre-configured evaluation flows (e.g. Evaluations (Generation + Scoring) Page). 2. (Optional) Register datasets and eval tasks as resources. If you want to run pre-configured evaluation flows (e.g. Evaluations (Generation + Scoring) Page).
```bash ```bash
$ llama-stack-client datasets register \ llama-stack-client datasets register \
--dataset-id "mmlu" \ --dataset-id "mmlu" \
--provider-id "huggingface" \ --provider-id "huggingface" \
--url "https://huggingface.co/datasets/llamastack/evals" \ --url "https://huggingface.co/datasets/llamastack/evals" \
@ -26,7 +26,7 @@ $ llama-stack-client datasets register \
``` ```
```bash ```bash
$ llama-stack-client benchmarks register \ llama-stack-client benchmarks register \
--eval-task-id meta-reference-mmlu \ --eval-task-id meta-reference-mmlu \
--provider-id meta-reference \ --provider-id meta-reference \
--dataset-id mmlu \ --dataset-id mmlu \
@ -212,7 +212,7 @@ def run_evaluation_3():
benchmark_id=selected_benchmark, benchmark_id=selected_benchmark,
input_rows=[r], input_rows=[r],
scoring_functions=benchmarks[selected_benchmark].scoring_functions, scoring_functions=benchmarks[selected_benchmark].scoring_functions,
task_config=benchmark_config, benchmark_config=benchmark_config,
) )
for k in r.keys(): for k in r.keys():
@ -7,7 +7,6 @@
import streamlit as st import streamlit as st
from llama_stack_client.lib.agents.agent import Agent from llama_stack_client.lib.agents.agent import Agent
from llama_stack_client.lib.agents.event_logger import EventLogger from llama_stack_client.lib.agents.event_logger import EventLogger
from llama_stack_client.types.agent_create_params import AgentConfig
from llama_stack_client.types.memory_insert_params import Document from llama_stack_client.types.memory_insert_params import Document
from modules.api import llama_stack_api from modules.api import llama_stack_api
from modules.utils import data_url_from_file from modules.utils import data_url_from_file
@ -124,26 +123,22 @@ def rag_chat_page():
else: else:
strategy = {"type": "greedy"} strategy = {"type": "greedy"}
agent_config = AgentConfig( agent = Agent(
llama_stack_api.client,
model=selected_model, model=selected_model,
instructions=system_prompt, instructions=system_prompt,
sampling_params={ sampling_params={
"strategy": strategy, "strategy": strategy,
}, },
toolgroups=[ tools=[
dict( dict(
name="builtin::rag/knowledge_search", name="builtin::rag/knowledge_search",
args={ args={
"vector_db_ids": [vector_db_id for vector_db_id in selected_vector_dbs], "vector_db_ids": list(selected_vector_dbs),
}, },
) )
], ],
tool_choice="auto",
tool_prompt_format="json",
enable_session_persistence=False,
) )
agent = Agent(llama_stack_api.client, agent_config)
session_id = agent.create_session("rag-session") session_id = agent.create_session("rag-session")
# Chat input # Chat input
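For reference, the migrated construction pattern in isolation: a minimal sketch in which the base URL, model id, and vector-db id are placeholders, and the client is created directly rather than taken from the playground's `llama_stack_api` module.

```python
# Migration sketch: Agent is now constructed directly, without AgentConfig.
from llama_stack_client import LlamaStackClient
from llama_stack_client.lib.agents.agent import Agent

client = LlamaStackClient(base_url="http://localhost:8321")  # placeholder URL
agent = Agent(
    client,
    model="meta-llama/Llama-3.1-8B-Instruct",  # placeholder model id
    instructions="You are a helpful assistant.",
    sampling_params={"strategy": {"type": "greedy"}},
    tools=[
        dict(
            name="builtin::rag/knowledge_search",
            args={"vector_db_ids": ["my-vector-db"]},  # placeholder vector-db id
        )
    ],
)
session_id = agent.create_session("rag-session")
```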
@ -13,6 +13,4 @@ DISTRIBS_BASE_DIR = LLAMA_STACK_CONFIG_DIR / "distributions"
DEFAULT_CHECKPOINT_DIR = LLAMA_STACK_CONFIG_DIR / "checkpoints" DEFAULT_CHECKPOINT_DIR = LLAMA_STACK_CONFIG_DIR / "checkpoints"
BUILDS_BASE_DIR = LLAMA_STACK_CONFIG_DIR / "builds"
RUNTIME_BASE_DIR = LLAMA_STACK_CONFIG_DIR / "runtime" RUNTIME_BASE_DIR = LLAMA_STACK_CONFIG_DIR / "runtime"
llama_stack/logcat.py (new file, 204 lines)
@ -0,0 +1,204 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
"""
Category-based logging utility for llama-stack.
This module provides a wrapper over the standard Python logging module that supports
categorized logging with environment variable control.
Usage:
from llama_stack import logcat
logcat.info("server", "Starting up...")
logcat.debug("inference", "Processing request...")
Environment variable:
LLAMA_STACK_LOGGING: Semicolon-separated list of category=level pairs
Example: "server=debug;inference=warning"
"""
import datetime
import logging
import os
from typing import Dict
# ANSI color codes for terminal output
COLORS = {
"RESET": "\033[0m",
"DEBUG": "\033[36m", # Cyan
"INFO": "\033[32m", # Green
"WARNING": "\033[33m", # Yellow
"ERROR": "\033[31m", # Red
"CRITICAL": "\033[35m", # Magenta
"DIM": "\033[2m", # Dimmed text
"YELLOW_DIM": "\033[2;33m", # Dimmed yellow
}
# Static list of valid categories representing various parts of the Llama Stack
# server codebase
CATEGORIES = [
"core",
"server",
"router",
"inference",
"agents",
"safety",
"eval",
"tools",
"client",
]
_logger = logging.getLogger("llama_stack")
_logger.propagate = False
_default_level = logging.INFO
# Category-level mapping (can be modified by environment variables)
_category_levels: Dict[str, int] = {}
class TerminalStreamHandler(logging.StreamHandler):
def __init__(self, stream=None):
super().__init__(stream)
self.is_tty = hasattr(self.stream, "isatty") and self.stream.isatty()
def format(self, record):
record.is_tty = self.is_tty
return super().format(record)
class ColoredFormatter(logging.Formatter):
"""Custom formatter with colors and fixed-width level names"""
def format(self, record):
levelname = record.levelname
# Use only time with milliseconds, not date
timestamp = datetime.datetime.now().strftime("%H:%M:%S.%f")[:-3] # HH:MM:SS.mmm format
file_info = f"{record.filename}:{record.lineno}"
# Get category from extra if available
category = getattr(record, "category", None)
msg = record.getMessage()
if getattr(record, "is_tty", False):
color = COLORS.get(levelname, COLORS["RESET"])
if category:
category_formatted = f"{COLORS['YELLOW_DIM']}{category}{COLORS['RESET']} "
formatted_msg = (
f"{color}{levelname:<7}{COLORS['RESET']} {COLORS['DIM']}{timestamp}{COLORS['RESET']} "
f"{file_info:<20} {category_formatted}{msg}"
)
else:
formatted_msg = (
f"{color}{levelname:<7}{COLORS['RESET']} {COLORS['DIM']}{timestamp}{COLORS['RESET']}] "
f"{file_info:<20} {msg}"
)
else:
if category:
formatted_msg = f"{levelname:<7} {timestamp} {file_info:<20} [{category}] {msg}"
else:
formatted_msg = f"{levelname:<7} {timestamp} {file_info:<20} {msg}"
return formatted_msg
def init(default_level: int = logging.INFO) -> None:
global _default_level, _category_levels, _logger
_default_level = default_level
_logger.setLevel(logging.DEBUG)
_logger.handlers = [] # Clear existing handlers
# Add our custom handler with the colored formatter
handler = TerminalStreamHandler()
formatter = ColoredFormatter()
handler.setFormatter(formatter)
_logger.addHandler(handler)
for category in CATEGORIES:
_category_levels[category] = default_level
env_config = os.environ.get("LLAMA_STACK_LOGGING", "")
if env_config:
for pair in env_config.split(";"):
if not pair.strip():
continue
try:
category, level = pair.split("=", 1)
category = category.strip().lower()
level = level.strip().lower()
level_value = {
"debug": logging.DEBUG,
"info": logging.INFO,
"warning": logging.WARNING,
"warn": logging.WARNING,
"error": logging.ERROR,
"critical": logging.CRITICAL,
}.get(level)
if level_value is None:
_logger.warning(f"Unknown log level '{level}' for category '{category}'")
continue
if category == "all":
for cat in CATEGORIES:
_category_levels[cat] = level_value
else:
if category in CATEGORIES:
_category_levels[category] = level_value
else:
_logger.warning(f"Unknown logging category: {category}")
except ValueError:
_logger.warning(f"Invalid logging configuration: {pair}")
def _should_log(level: int, category: str) -> bool:
category = category.lower()
if category not in _category_levels:
return False
category_level = _category_levels[category]
return level >= category_level
def _log(level: int, level_name: str, category: str, msg: str, *args, **kwargs) -> None:
if _should_log(level, category):
kwargs.setdefault("extra", {})["category"] = category.lower()
getattr(_logger, level_name)(msg, *args, stacklevel=3, **kwargs)
def debug(category: str, msg: str, *args, **kwargs) -> None:
_log(logging.DEBUG, "debug", category, msg, *args, **kwargs)
def info(category: str, msg: str, *args, **kwargs) -> None:
_log(logging.INFO, "info", category, msg, *args, **kwargs)
def warning(category: str, msg: str, *args, **kwargs) -> None:
_log(logging.WARNING, "warning", category, msg, *args, **kwargs)
def warn(category: str, msg: str, *args, **kwargs) -> None:
warning(category, msg, *args, **kwargs)
def error(category: str, msg: str, *args, **kwargs) -> None:
_log(logging.ERROR, "error", category, msg, *args, **kwargs)
def critical(category: str, msg: str, *args, **kwargs) -> None:
_log(logging.CRITICAL, "critical", category, msg, *args, **kwargs)
def exception(category: str, msg: str, *args, **kwargs) -> None:
if _should_log(logging.ERROR, category):
kwargs.setdefault("extra", {})["category"] = category.lower()
_logger.exception(msg, *args, stacklevel=2, **kwargs)
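A minimal usage sketch of the category logger, following the module docstring above; the category names come from the static `CATEGORIES` list.

```python
# Minimal usage sketch of the category logger.
import os

os.environ["LLAMA_STACK_LOGGING"] = "server=debug;inference=warning"

from llama_stack import logcat

logcat.init()                                    # reads LLAMA_STACK_LOGGING
logcat.debug("server", "verbose server detail")  # emitted: server is at DEBUG
logcat.info("inference", "routine message")      # suppressed: inference is at WARNING
logcat.warning("inference", "slow request")      # emitted
```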
@ -11,16 +11,128 @@
# top-level folder for each specific model found within the models/ directory at # top-level folder for each specific model found within the models/ directory at
# the top-level of this source tree. # the top-level of this source tree.
import base64
from enum import Enum from enum import Enum
from typing import Any, Dict, Literal, Optional, Union from io import BytesIO
from typing import Any, Dict, List, Literal, Optional, Union
# import all for backwards compatibility from pydantic import BaseModel, ConfigDict, Field, field_serializer, field_validator
from llama_models.datatypes import * # noqa: F403
from pydantic import BaseModel, ConfigDict, Field, field_validator
from typing_extensions import Annotated from typing_extensions import Annotated
from llama_stack.schema_utils import json_schema_type, register_schema from llama_stack.schema_utils import json_schema_type, register_schema
# The goal is that this set of types is relevant for all Llama models.
# That isn't the current state yet -- e.g., BuiltinTool is somewhat specific to
# the llama3 series of models.
class Role(Enum):
system = "system"
user = "user"
assistant = "assistant"
tool = "tool"
class BuiltinTool(Enum):
brave_search = "brave_search"
wolfram_alpha = "wolfram_alpha"
photogen = "photogen"
code_interpreter = "code_interpreter"
Primitive = Union[str, int, float, bool, None]
RecursiveType = Union[Primitive, List[Primitive], Dict[str, Primitive]]
class ToolCall(BaseModel):
call_id: str
tool_name: Union[BuiltinTool, str]
arguments: Dict[str, RecursiveType]
@field_validator("tool_name", mode="before")
@classmethod
def validate_field(cls, v):
if isinstance(v, str):
try:
return BuiltinTool(v)
except ValueError:
return v
return v
class ToolPromptFormat(Enum):
"""Prompt format for calling custom / zero shot tools.
:cvar json: JSON format for calling tools. It takes the form:
{
"type": "function",
"function" : {
"name": "function_name",
"description": "function_description",
"parameters": {...}
}
}
:cvar function_tag: Function tag format, pseudo-XML. This looks like:
<function=function_name>(parameters)</function>
:cvar python_list: Python list. The output is a valid Python expression that can be
evaluated to a list. Each element in the list is a function call. Example:
["function_name(param1, param2)", "function_name(param1, param2)"]
"""
json = "json"
function_tag = "function_tag"
python_list = "python_list"
class StopReason(Enum):
end_of_turn = "end_of_turn"
end_of_message = "end_of_message"
out_of_tokens = "out_of_tokens"
class RawMediaItem(BaseModel):
type: Literal["image"] = "image"
data: bytes | BytesIO
model_config = ConfigDict(arbitrary_types_allowed=True)
@field_serializer("data")
def serialize_data(self, data: Optional[bytes], _info):
if data is None:
return None
return base64.b64encode(data).decode("utf-8")
@field_validator("data", mode="before")
@classmethod
def validate_data(cls, v):
if isinstance(v, str):
return base64.b64decode(v)
return v
class RawTextItem(BaseModel):
type: Literal["text"] = "text"
text: str
RawContentItem = Annotated[Union[RawTextItem, RawMediaItem], Field(discriminator="type")]
RawContent = str | RawContentItem | List[RawContentItem]
class RawMessage(BaseModel):
role: Literal["user"] | Literal["system"] | Literal["tool"] | Literal["assistant"]
content: RawContent
# This is for RAG but likely should be absorbed into content
context: Optional[RawContent] = None
# These are for the output message coming from the assistant
stop_reason: Optional[StopReason] = None
tool_calls: List[ToolCall] = Field(default_factory=list)
register_schema(ToolCall) register_schema(ToolCall)
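A small sketch of the relocated datatypes, showing the `tool_name` validator coercing a known builtin name while leaving custom names as plain strings. The import path is the one used elsewhere in this diff; the ids and arguments are placeholders.

```python
from llama_stack.models.llama.datatypes import BuiltinTool, RawMessage, StopReason, ToolCall

call = ToolCall(call_id="c-1", tool_name="brave_search", arguments={"query": "llama stack"})
assert call.tool_name == BuiltinTool.brave_search   # known builtin names are coerced to the enum

custom = ToolCall(call_id="c-2", tool_name="get_weather", arguments={"city": "Tokyo"})
assert custom.tool_name == "get_weather"            # unknown names stay as plain strings

msg = RawMessage(
    role="assistant",
    content="",
    stop_reason=StopReason.end_of_turn,
    tool_calls=[call],
)
```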
@ -0,0 +1,282 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# top-level folder for each specific model found within the models/ directory at
# the top-level of this source tree.
import io
import uuid
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple
from PIL import Image as PIL_Image
from llama_stack.models.llama.datatypes import (
BuiltinTool,
RawContent,
RawMediaItem,
RawMessage,
RawTextItem,
Role,
StopReason,
ToolCall,
ToolPromptFormat,
)
from .tokenizer import Tokenizer
from .tool_utils import ToolUtils
@dataclass
class VisionInput:
mask: List[List[int]]
images: List[PIL_Image.Image]
@dataclass
class LLMInput:
tokens: List[int]
vision: Optional[VisionInput] = None
def role_str(role: Role) -> str:
role_strs = {
Role.user: "user",
Role.system: "system",
Role.tool: "ipython", # special
Role.assistant: "assistant",
}
return role_strs[role]
class ChatFormat:
possible_headers: Dict[Role, str]
def __init__(self, tokenizer: Tokenizer):
self.tokenizer = tokenizer
self.possible_headers = {role: f"<|start_header_id|>{role_str(role)}<|end_header_id|>\n\n" for role in Role}
self.vision_token = self.tokenizer.special_tokens["<|image|>"]
def _encode_header(self, role: str) -> List[int]:
tokens = []
tokens.append(self.tokenizer.special_tokens["<|start_header_id|>"])
tokens.extend(self.tokenizer.encode("ipython" if role == "tool" else role, bos=False, eos=False))
tokens.append(self.tokenizer.special_tokens["<|end_header_id|>"])
tokens.extend(self.tokenizer.encode("\n\n", bos=False, eos=False))
return tokens
def encode_content(self, content: RawContent) -> LLMInput:
tokens, images = self._encode_content(content, bos=True)
return self._model_input_from_tokens_images(tokens, images)
def _encode_content(self, content: RawContent, bos: bool = False) -> Tuple[List[int], List[PIL_Image.Image]]:
tokens = []
images = []
added_bos = False
def _process(c):
nonlocal added_bos, bos
if isinstance(c, str) or isinstance(c, RawTextItem):
if isinstance(c, RawTextItem):
c = c.text
tokens.extend(self.tokenizer.encode(c, bos=False if added_bos else bos, eos=False))
added_bos = True
elif isinstance(c, RawMediaItem):
bos = False if added_bos else bos
if bos:
tokens.append(self.tokenizer.special_tokens["<|begin_of_text|>"])
added_bos = True
tokens.append(self.vision_token)
bytes_io = io.BytesIO(c.data) if isinstance(c.data, bytes) else c.data
image = PIL_Image.open(bytes_io)
image = image.convert("RGB")
images.append(image)
if isinstance(content, list):
for c in content:
_process(c)
else:
_process(content)
return tokens, images
def encode_message(
self, message: RawMessage, tool_prompt_format: ToolPromptFormat
) -> Tuple[List[int], List[PIL_Image.Image]]:
tokens = self._encode_header(message.role)
images = []
def _process_content(c):
toks, imgs = self._encode_content(c)
tokens.extend(toks)
images.extend(imgs)
if (
message.role == "assistant"
and len(message.tool_calls) > 0
and message.tool_calls[0].tool_name == BuiltinTool.code_interpreter
):
tokens.append(self.tokenizer.special_tokens["<|python_tag|>"])
_process_content(message.content)
if message.role == "user" and message.context is not None:
# This is RAG context; why is it here in the chat format? I don't think
# this is needed and can be moved upwards
_process_content("\n\n")
_process_content(message.context)
if message.role == "assistant":
for t in message.tool_calls:
content = ToolUtils.encode_tool_call(t, tool_prompt_format)
_process_content(content)
eom = False
if message.role == "assistant":
eom = message.stop_reason == StopReason.end_of_message
tokens.append(self.tokenizer.special_tokens["<|eom_id|>" if eom else "<|eot_id|>"])
return tokens, images
def encode_dialog_prompt(
self,
messages: List[RawMessage],
tool_prompt_format: Optional[ToolPromptFormat] = None,
) -> LLMInput:
tool_prompt_format = tool_prompt_format or ToolPromptFormat.json
tokens = []
images = []
tokens.append(self.tokenizer.special_tokens["<|begin_of_text|>"])
for message in messages:
toks, imgs = self.encode_message(message, tool_prompt_format)
tokens.extend(toks)
images.extend(imgs)
# Add the start of an assistant message for the model to complete.
tokens.extend(self._encode_header("assistant"))
return self._model_input_from_tokens_images(tokens, images)
# TODO(this should be generic, not only for assistant messages)
def decode_assistant_message(self, tokens: List[int], stop_reason: StopReason) -> RawMessage:
content = self.tokenizer.decode(tokens)
return self.decode_assistant_message_from_content(content, stop_reason)
def decode_assistant_message_from_content(self, content: str, stop_reason: StopReason) -> RawMessage:
content = content.strip(" ")
header_str = self.possible_headers[Role.assistant]
if content.startswith(header_str):
content = content[len(header_str) :]
ipython = content.startswith("<|python_tag|>")
if ipython:
content = content[len("<|python_tag|>") :]
if content.endswith("<|eot_id|>"):
content = content[: -len("<|eot_id|>")]
stop_reason = StopReason.end_of_turn
elif content.endswith("<|eom_id|>"):
content = content[: -len("<|eom_id|>")]
stop_reason = StopReason.end_of_message
tool_name = None
tool_arguments = {}
custom_tool_info = ToolUtils.maybe_extract_custom_tool_call(content)
if custom_tool_info is not None:
tool_name, tool_arguments = custom_tool_info
# Sometimes, when the agent has custom tools alongside builtin tools,
# it responds to builtin tool calls in the format of the custom tools.
# This code tries to handle that case.
if tool_name in BuiltinTool.__members__:
tool_name = BuiltinTool[tool_name]
tool_arguments = {
"query": list(tool_arguments.values())[0],
}
else:
builtin_tool_info = ToolUtils.maybe_extract_builtin_tool_call(content)
if builtin_tool_info is not None:
tool_name, query = builtin_tool_info
tool_arguments = {
"query": query,
}
if tool_name in BuiltinTool.__members__:
tool_name = BuiltinTool[tool_name]
elif ipython:
tool_name = BuiltinTool.code_interpreter
tool_arguments = {
"code": content,
}
tool_calls = []
if tool_name is not None and tool_arguments is not None:
call_id = str(uuid.uuid4())
tool_calls.append(
ToolCall(
call_id=call_id,
tool_name=tool_name,
arguments=tool_arguments,
)
)
content = ""
return RawMessage(
role="assistant",
content=content,
stop_reason=stop_reason,
tool_calls=tool_calls,
)
def _model_input_from_tokens_images(self, tokens: List[int], images: List[PIL_Image.Image]) -> LLMInput:
vision_input = None
if len(images) > 0:
vision_input = VisionInput(
mask=create_vision_mask(tokens, self.vision_token),
images=images,
)
return LLMInput(
tokens=[128256 if token == self.vision_token else token for token in tokens],
vision=vision_input,
)
def create_vision_mask(
tokens: List[int],
vision_token: int,
) -> List[List[int]]:
vision_token_locations = [i for i, token in enumerate(tokens) if token == vision_token]
if len(vision_token_locations) == 0:
return []
if len(vision_token_locations) == 1:
# only one image present, unmask until end of sequence
return [[vision_token_locations[0], -1]]
vision_masks = [
[loc1, loc2] for loc1, loc2 in zip(vision_token_locations[:-1], vision_token_locations[1:], strict=False)
]
# last image will attend to all subsequent text
vision_masks.append([vision_token_locations[-1], len(tokens)])
# if there are two or more consecutive vision tokens,
# they should all attend to all subsequent
# text present
last_mask_end = vision_masks[-1][1]
for vision_mask in vision_masks[::-1]:
if vision_mask[0] == vision_mask[1] - 1:
vision_mask[1] = last_mask_end
last_mask_end = vision_mask[1]
return vision_masks
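A brief sketch of what `create_vision_mask` produces; the module path is inferred from the relative imports above and the vision-token id is a placeholder.

```python
# Sketch of create_vision_mask; module path inferred, token id is a placeholder.
from llama_stack.models.llama.llama3.chat_format import create_vision_mask

VISION = 128256  # stand-in for tokenizer.special_tokens["<|image|>"]

# Two consecutive image tokens followed by text: both attend to all subsequent text.
print(create_vision_mask([VISION, VISION, 10, 11, 12], VISION))      # [[0, 5], [1, 5]]

# Two separated images: the first attends up to the second, the last to everything after it.
print(create_vision_mask([10, VISION, 11, VISION, 12, 13], VISION))  # [[1, 3], [3, 6]]
```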
@ -14,20 +14,19 @@
from pathlib import Path from pathlib import Path
from typing import List, Optional from typing import List, Optional
from llama_models.datatypes import ( from termcolor import colored
from llama_stack.models.llama.datatypes import (
BuiltinTool, BuiltinTool,
RawMessage, RawMessage,
StopReason, StopReason,
ToolCall, ToolCall,
ToolDefinition,
ToolPromptFormat, ToolPromptFormat,
) )
from llama_models.llama3.api.chat_format import ChatFormat
from llama_models.llama3.api.tokenizer import Tokenizer
from termcolor import colored
from llama_stack.models.llama.datatypes import ToolDefinition
from . import template_data from . import template_data
from .chat_format import ChatFormat
from .prompt_templates import ( from .prompt_templates import (
BuiltinToolGenerator, BuiltinToolGenerator,
FunctionTagCustomToolGenerator, FunctionTagCustomToolGenerator,
@ -35,6 +34,7 @@ from .prompt_templates import (
SystemDefaultGenerator, SystemDefaultGenerator,
ToolResponseGenerator, ToolResponseGenerator,
) )
from .tokenizer import Tokenizer
THIS_DIR = Path(__file__).parent THIS_DIR = Path(__file__).parent
@ -15,11 +15,8 @@ import textwrap
from datetime import datetime from datetime import datetime
from typing import Any, List, Optional from typing import Any, List, Optional
from llama_models.datatypes import (
BuiltinTool,
)
from llama_stack.models.llama.datatypes import ( from llama_stack.models.llama.datatypes import (
BuiltinTool,
ToolDefinition, ToolDefinition,
ToolParamDefinition, ToolParamDefinition,
) )
@ -11,7 +11,7 @@
# top-level folder for each specific model found within the models/ directory at # top-level folder for each specific model found within the models/ directory at
# the top-level of this source tree. # the top-level of this source tree.
from llama_models.datatypes import ( from llama_stack.models.llama.datatypes import (
BuiltinTool, BuiltinTool,
StopReason, StopReason,
ToolCall, ToolCall,
(One file's diff was suppressed because it is too large.)
@ -0,0 +1,214 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# top-level folder for each specific model found within the models/ directory at
# the top-level of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
import os
from logging import getLogger
from pathlib import Path
from typing import (
AbstractSet,
Collection,
Dict,
Iterator,
List,
Literal,
Optional,
Sequence,
Union,
cast,
)
import tiktoken
from tiktoken.load import load_tiktoken_bpe
logger = getLogger(__name__)
# The tiktoken tokenizer can handle <=400k chars without
# pyo3_runtime.PanicException.
TIKTOKEN_MAX_ENCODE_CHARS = 400_000
# https://github.com/openai/tiktoken/issues/195
# Here we iterate over subsequences and split if we exceed the limit
# of max consecutive non-whitespace or whitespace characters.
MAX_NO_WHITESPACES_CHARS = 25_000
_INSTANCE = None
class Tokenizer:
"""
Tokenizing and encoding/decoding text using the Tiktoken tokenizer.
"""
special_tokens: Dict[str, int]
num_reserved_special_tokens = 256
pat_str = r"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+" # noqa: E501
@classmethod
def get_instance(cls):
global _INSTANCE
if _INSTANCE is None:
_INSTANCE = Tokenizer(os.path.join(os.path.dirname(__file__), "tokenizer.model"))
return _INSTANCE
def __init__(self, model_path: str):
"""
Initializes the Tokenizer with a Tiktoken model.
Args:
model_path (str): The path to the Tiktoken model file.
"""
assert os.path.isfile(model_path), model_path
mergeable_ranks = load_tiktoken_bpe(model_path)
num_base_tokens = len(mergeable_ranks)
special_tokens = [
"<|begin_of_text|>",
"<|end_of_text|>",
"<|reserved_special_token_0|>",
"<|reserved_special_token_1|>",
"<|finetune_right_pad_id|>",
"<|step_id|>",
"<|start_header_id|>",
"<|end_header_id|>",
"<|eom_id|>", # end of message
"<|eot_id|>", # end of turn
"<|python_tag|>",
"<|image|>",
]
reserved_tokens = [
f"<|reserved_special_token_{2 + i}|>" for i in range(self.num_reserved_special_tokens - len(special_tokens))
]
special_tokens = special_tokens + reserved_tokens
self.special_tokens = {token: num_base_tokens + i for i, token in enumerate(special_tokens)}
self.model = tiktoken.Encoding(
name=Path(model_path).name,
pat_str=self.pat_str,
mergeable_ranks=mergeable_ranks,
special_tokens=self.special_tokens,
)
self.n_words: int = num_base_tokens + len(special_tokens)
# BOS / EOS token IDs
self.bos_id: int = self.special_tokens["<|begin_of_text|>"]
self.eos_id: int = self.special_tokens["<|end_of_text|>"]
self.eot_id: int = self.special_tokens["<|eot_id|>"]
self.eom_id: int = self.special_tokens["<|eom_id|>"]
self.python_tag_id = self.special_tokens["<|python_tag|>"]
self.pad_id: int = self.special_tokens["<|finetune_right_pad_id|>"]
self.stop_tokens = [
self.eos_id,
self.special_tokens["<|eom_id|>"],
self.special_tokens["<|eot_id|>"],
]
def encode(
self,
s: str,
*,
bos: bool,
eos: bool,
allowed_special: Optional[Union[Literal["all"], AbstractSet[str]]] = None,
disallowed_special: Union[Literal["all"], Collection[str]] = (),
) -> List[int]:
"""
Encodes a string into a list of token IDs.
Args:
s (str): The input string to be encoded.
bos (bool): Whether to prepend the beginning-of-sequence token.
eos (bool): Whether to append the end-of-sequence token.
allowed_special ("all"|set[str]): allowed special tokens in string
disallowed_special ("all"|set[str]): special tokens that raise an error when in string
Returns:
list[int]: A list of token IDs.
By default, setting disallowed_special=() encodes a string by ignoring
special tokens. Specifically:
- Setting `disallowed_special` to () will cause all text corresponding
to special tokens to be encoded as natural text (instead of raising
an error).
- Setting `allowed_special` to "all" will cause all text corresponding
to special tokens to be encoded as special tokens.
"""
if allowed_special is None:
allowed_special = set()
assert type(s) is str
substrs = (
substr
for i in range(0, len(s), TIKTOKEN_MAX_ENCODE_CHARS)
for substr in self._split_whitespaces_or_nonwhitespaces(
s[i : i + TIKTOKEN_MAX_ENCODE_CHARS], MAX_NO_WHITESPACES_CHARS
)
)
t: List[int] = []
for substr in substrs:
t.extend(
self.model.encode(
substr,
allowed_special=allowed_special,
disallowed_special=disallowed_special,
)
)
if bos:
t.insert(0, self.bos_id)
if eos:
t.append(self.eos_id)
return t
def decode(self, t: Sequence[int]) -> str:
"""
Decodes a list of token IDs into a string.
Args:
t (List[int]): The list of token IDs to be decoded.
Returns:
str: The decoded string.
"""
# Typecast is safe here. Tiktoken doesn't do anything list-related with the sequence.
return self.model.decode(cast(List[int], t))
@staticmethod
def _split_whitespaces_or_nonwhitespaces(s: str, max_consecutive_slice_len: int) -> Iterator[str]:
"""
Splits the string `s` so that each substring contains no more than `max_consecutive_slice_len`
consecutive whitespaces or consecutive non-whitespaces.
"""
current_slice_len = 0
current_slice_is_space = s[0].isspace() if len(s) > 0 else False
slice_start = 0
for i in range(len(s)):
is_now_space = s[i].isspace()
if current_slice_is_space ^ is_now_space:
current_slice_len = 1
current_slice_is_space = is_now_space
else:
current_slice_len += 1
if current_slice_len > max_consecutive_slice_len:
yield s[slice_start:i]
slice_start = i
current_slice_len = 1
yield s[slice_start:]
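# --- Illustrative usage sketch (not part of the original file) ---
# Shows an encode/decode round trip with the Tokenizer above. The path below is a
# hypothetical location for the Llama 3.1 tokenizer.model file.
if __name__ == "__main__":
    tokenizer = Tokenizer("/path/to/tokenizer.model")  # hypothetical path
    ids = tokenizer.encode("Hello, Llama!", bos=True, eos=True)
    assert ids[0] == tokenizer.bos_id and ids[-1] == tokenizer.eos_id
    # decode() round-trips the text, including the special BOS/EOS markers.
    print(tokenizer.decode(ids))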


@@ -0,0 +1,199 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# top-level folder for each specific model found within the models/ directory at
# the top-level of this source tree.
import ast
import json
import re
from typing import Optional, Tuple
from llama_stack.models.llama.datatypes import BuiltinTool, RecursiveType, ToolCall, ToolPromptFormat
BUILTIN_TOOL_PATTERN = r'\b(?P<tool_name>\w+)\.call\(query="(?P<query>[^"]*)"\)'
CUSTOM_TOOL_CALL_PATTERN = re.compile(r"<function=(?P<function_name>[^}]+)>(?P<args>{.*?})")
def is_json(s):
try:
parsed = json.loads(s)
# Return True for valid objects and not for ints, strings, etc
return isinstance(parsed, dict)
except json.JSONDecodeError:
return False
def is_valid_python_list(input_string):
"""Check if the input string is a valid Python list of function calls"""
try:
# Try to parse the string
tree = ast.parse(input_string)
# Check if it's a single expression
if len(tree.body) != 1 or not isinstance(tree.body[0], ast.Expr):
return False
# Check if the expression is a list
expr = tree.body[0].value
if not isinstance(expr, ast.List):
return False
# Check if the list is empty
if len(expr.elts) == 0:
return False
# Check if all elements in the list are function calls
for element in expr.elts:
if not isinstance(element, ast.Call):
return False
# Check if the function call has a valid name
if not isinstance(element.func, ast.Name):
return False
# Check if all arguments are keyword arguments
if element.args or not all(isinstance(arg, ast.keyword) for arg in element.keywords):
return False
return True
except SyntaxError:
# If parsing fails, it's not a valid Python expression
return False
def parse_python_list_for_function_calls(input_string):
"""
Parse a Python list of function calls and
return a list of tuples containing the function name and arguments
"""
# Parse the string into an AST
tree = ast.parse(input_string)
# Ensure the input is a list
if not isinstance(tree.body[0], ast.Expr) or not isinstance(tree.body[0].value, ast.List):
raise ValueError("Input must be a list of function calls")
result = []
# Iterate through each function call in the list
for node in tree.body[0].value.elts:
if isinstance(node, ast.Call):
function_name = node.func.id
function_args = {}
# Extract keyword arguments
for keyword in node.keywords:
function_args[keyword.arg] = ast.literal_eval(keyword.value)
result.append((function_name, function_args))
return result
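# Illustrative example (not part of the original file). For a hypothetical tool
# call string in the python_list format, the helpers above behave as follows:
#   is_valid_python_list('[get_weather(city="Paris", unit="celsius")]')
#     -> True
#   parse_python_list_for_function_calls('[get_weather(city="Paris", unit="celsius")]')
#     -> [('get_weather', {'city': 'Paris', 'unit': 'celsius'})]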
class ToolUtils:
@staticmethod
def is_builtin_tool_call(message_body: str) -> bool:
match = re.search(BUILTIN_TOOL_PATTERN, message_body)
return match is not None
@staticmethod
def maybe_extract_builtin_tool_call(message_body: str) -> Optional[Tuple[str, str]]:
# Find the first match in the text
match = re.search(BUILTIN_TOOL_PATTERN, message_body)
# Check if a match is found and return it
if match:
tool_name = match.group("tool_name")
query = match.group("query")
return tool_name, query
else:
return None
@staticmethod
def maybe_extract_custom_tool_call(message_body: str) -> Optional[Tuple[str, str]]:
# NOTE: Custom function tool calls are still experimental
# Sometimes, the response is of the form
# {"type": "function", "name": "function_name", "parameters": {...}}
# and sometimes
# <function=function_name>(parameters)</function>
# Find the first match in the text
match = re.search(CUSTOM_TOOL_CALL_PATTERN, message_body)
if match:
tool_name = match.group("function_name")
query = match.group("args")
try:
return tool_name, json.loads(query.replace("'", '"'))
except Exception as e:
print("Exception while parsing json query for custom tool call", query, e)
return None
elif is_json(message_body):
response = json.loads(message_body)
if ("type" in response and response["type"] == "function") or ("name" in response):
function_name = response["name"]
args = response["parameters"]
return function_name, args
else:
return None
elif is_valid_python_list(message_body):
res = parse_python_list_for_function_calls(message_body)
# FIXME: Enable multiple tool calls
return res[0]
else:
return None
@staticmethod
def encode_tool_call(t: ToolCall, tool_prompt_format: ToolPromptFormat) -> str:
if t.tool_name == BuiltinTool.brave_search:
q = t.arguments["query"]
return f'brave_search.call(query="{q}")'
elif t.tool_name == BuiltinTool.wolfram_alpha:
q = t.arguments["query"]
return f'wolfram_alpha.call(query="{q}")'
elif t.tool_name == BuiltinTool.photogen:
q = t.arguments["query"]
return f'photogen.call(query="{q}")'
elif t.tool_name == BuiltinTool.code_interpreter:
return t.arguments["code"]
else:
fname = t.tool_name
if tool_prompt_format == ToolPromptFormat.json:
return json.dumps(
{
"type": "function",
"name": fname,
"parameters": t.arguments,
}
)
elif tool_prompt_format == ToolPromptFormat.function_tag:
args = json.dumps(t.arguments)
return f"<function={fname}>{args}</function>"
elif tool_prompt_format == ToolPromptFormat.python_list:
def format_value(value: RecursiveType) -> str:
if isinstance(value, str):
return f'"{value}"'
elif isinstance(value, (int, float, bool)) or value is None:
return str(value)
elif isinstance(value, list):
return f"[{', '.join(format_value(v) for v in value)}]"
elif isinstance(value, dict):
return f"{{{', '.join(f'{k}={format_value(v)}' for k, v in value.items())}}}"
else:
raise ValueError(f"Unsupported type: {type(value)}")
args_str = ", ".join(f"{k}={format_value(v)}" for k, v in t.arguments.items())
return f"[{fname}({args_str})]"
else:
raise ValueError(f"Unsupported tool prompt format: {tool_prompt_format}")


@@ -0,0 +1,358 @@
# Llama 3.1 - Prompt Formats
## Tokens
Here is a list of special tokens that are supported by Llama 3.1:
- `<|begin_of_text|>`: Specifies the start of the prompt
- `<|end_of_text|>`: Model will cease to generate more tokens. This token is generated only by the base models.
- `<|finetune_right_pad_id|>`: This token is used for padding text sequences to the same length in a batch.
- `<|start_header_id|>` and `<|end_header_id|>`: These tokens enclose the role for a particular message. The possible roles are: [system, user, assistant and ipython]
- `<|eom_id|>`: End of message. A message represents a possible stopping point for execution where the model can inform the executor that a tool call needs to be made. This is used for multi-step interactions between the model and any available tools. This token is emitted by the model when the Environment: ipython instruction is used in the system prompt, or if the model calls for a built-in tool.
- `<|eot_id|>`: End of turn. Represents when the model has determined that it has finished interacting with the user message that initiated its response. This is used in two scenarios:
- at the end of a direct interaction between the model and the user
- at the end of multiple interactions between the model and any available tools
This token signals to the executor that the model has finished generating a response.
- `<|python_tag|>`: Is a special tag used in the model's response to signify a tool call.
There are 4 different roles that are supported by Llama 3.1:
- `system`: Sets the context in which to interact with the AI model. It typically includes rules, guidelines, or necessary information that helps the model respond effectively.
- `user`: Represents the human interacting with the model. It includes the inputs, commands, and questions to the model.
- `ipython`: A new role introduced in Llama 3.1. Semantically, this role means "tool". This role is used to mark messages with the output of a tool call when sent back to the model from the executor.
- `assistant`: Represents the response generated by the AI model based on the context provided in the `system`, `ipython` and `user` prompts.
## Llama 3.1 Base Model
Text completion for Llama 3.1 base model uses this format.
##### Input Prompt Format
```
<|begin_of_text|>Color of sky is blue but sometimes can also be
```
##### Model Response Format
```
red, orange, yellow, green, purple, pink, brown, gray, black, white, and even rainbow colors. The color of the sky can change due to various reasons such as time of day, weather conditions, pollution, and atmospheric phenomena.
The color of the sky is primarily blue because of a phenomenon called
```
Note the special `<|begin_of_text|>` tag at the start of the prompt.
## Llama 3.1 Instruct Model
## User and assistant conversation
Here is a regular multi-turn user/assistant conversation and how it's formatted.
##### Input Prompt Format
```
<|begin_of_text|><|start_header_id|>system<|end_header_id|>
You are a helpful assistant<|eot_id|><|start_header_id|>user<|end_header_id|>
Answer who are you in the form of jeopardy?<|eot_id|><|start_header_id|>assistant<|end_header_id|>
```
##### Model Response Format
```
Here's my response
"What is a helpful assistant?"<|eot_id|>
```
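The format above can be produced programmatically. Below is a minimal sketch, assuming a simple list of role/content dicts; the helper name is illustrative and not part of any Llama Stack API.
```
def format_llama31_chat_prompt(messages: list[dict[str, str]]) -> str:
    # Wrap each message in the special header and turn tokens shown above.
    prompt = "<|begin_of_text|>"
    for m in messages:
        prompt += f"<|start_header_id|>{m['role']}<|end_header_id|>\n\n{m['content']}<|eot_id|>"
    # Leave the assistant header open so the model generates the next turn.
    return prompt + "<|start_header_id|>assistant<|end_header_id|>\n\n"

print(format_llama31_chat_prompt([
    {"role": "system", "content": "You are a helpful assistant"},
    {"role": "user", "content": "Answer who are you in the form of jeopardy?"},
]))
```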
## Tool Calling Formats
The three built-in tools (brave_search, wolfram_alpha, and code interpreter) can be turned on using the system prompt:
- Brave Search: Tool call to perform web searches.
- Wolfram Alpha: Tool call to perform complex mathematical calculations.
- Code Interpreter: Enables the model to output python code.
## Builtin Tool Calling
Here is an example of a conversation using brave search
##### Input Prompt Format
```
<|begin_of_text|><|start_header_id|>system<|end_header_id|>
Environment: ipython
Tools: brave_search, wolfram_alpha
Cutting Knowledge Date: December 2023
Today Date: 21 September 2024
You are a helpful assistant.
<|eot_id|><|start_header_id|>user<|end_header_id|>
Search the web for the latest price of 1oz gold?<|eot_id|><|start_header_id|>assistant<|end_header_id|>
```
##### Model Response Format
```
<|python_tag|>brave_search.call(query="latest price of 1oz gold")<|eom_id|>
```
- Just including Environment: ipython turns on code interpreter; therefore, you don't need to specify code interpretation on the Tools: line. The model can generate python code which is interpreted by the executor, with the result provided back to the model.
- The message body of the assistant response starts with a special tag <|python_tag|>
- As alluded to above, in such an environment, the model can generate <|eom_id|> instead of just the standard <|eot_id|> . The latter indicates the turn is finished, while the former indicates continued multi-step reasoning. That is, the model is expecting a continuation message with the output of the tool call.
- The model tool call response is of the form `tool.call(query="...")`, where `tool` is `brave_search` or `wolfram_alpha` (a minimal parsing sketch follows below)
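As a concrete illustration, here is a minimal sketch of how an executor might recognize such a built-in tool call; the regular expression and the response string are illustrative assumptions, not a prescribed API.
```
import re

BUILTIN_CALL = re.compile(r'(?P<tool>brave_search|wolfram_alpha)\.call\(query="(?P<query>[^"]*)"\)')

response = '<|python_tag|>brave_search.call(query="latest price of 1oz gold")<|eom_id|>'
match = BUILTIN_CALL.search(response)
if match:
    # Prints: brave_search -> latest price of 1oz gold
    print(match.group("tool"), "->", match.group("query"))
```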
## Builtin Code Interpreter
Here is an actual example of the model responding with code
##### Input Prompt Format
```
<|begin_of_text|><|start_header_id|>system<|end_header_id|>
Environment: ipython<|eot_id|><|start_header_id|>user<|end_header_id|>
Write code to check if number is prime, use that to see if the number 7 is prime<|eot_id|><|start_header_id|>assistant<|end_header_id|>
```
##### Model Response Format
```
<|python_tag|>def is_prime(n):
if n <= 1:
return False
for i in range(2, int(n**0.5) + 1):
if n % i == 0:
return False
return True
print(is_prime(7)) # Output: True<|eom_id|>
```
- Model starts with <|python_tag|> and continues writing python code that needs to be executed
- No explicit mention of code_interpreter in system prompt. `Environment: ipython` implicitly enables it.
## Built-in tools full interaction
Here is a full interaction with the built-in tools including the tool response and the final assistant response.
##### Input Prompt Format
```
<|begin_of_text|><|start_header_id|>system<|end_header_id|>
Environment: ipython
Tools: brave_search, wolfram_alpha
<|eot_id|><|start_header_id|>user<|end_header_id|>
What is the 100th decimal of pi?<|eot_id|><|start_header_id|>assistant<|end_header_id|>
<|python_tag|>wolfram_alpha.call(query="100th decimal of pi")<|eom_id|><|start_header_id|>ipython<|end_header_id|>
{
"queryresult": {
"success": true,
"inputstring": "100th decimal of pi",
"pods": [
{
"title": "Input interpretation",
"subpods": [
{
"title": "",
"plaintext": "100th digit | π"
}
]
},
{
"title": "Nearby digits",
"subpods": [
{
"title": "",
"plaintext": "...86208998628034825342117067982148086513282306647093..."
}
]
},
{
"title": "Result",
"primary": true,
"subpods": [
{
"title": "",
"plaintext": "7"
}
]
}
]
}
}
<|eot_id|><|start_header_id|>assistant<|end_header_id|>
```
##### Model Response Format
```
The 100th decimal of pi is 7.<|eot_id|>
```
- Note the `<|python_tag|>` in the assistant response.
- Role is `ipython` for the wolfram alpha response that is passed back to the model.
- Final message from assistant has <|eot_id|> tag.
## Zero shot tool calling
## JSON based tool calling
Llama models can now output custom tool calls from a single message to allow easier tool calling.
The following prompts provide an example of how custom tools can be called from the output of the model.
It's important to note that the model itself does not execute the calls; it provides structured output to facilitate calling by an executor.
##### Input Prompt Format
```
<|begin_of_text|><|start_header_id|>system<|end_header_id|>
Environment: ipython
Cutting Knowledge Date: December 2023
Today Date: 21 September 2024
You are a helpful assistant.
<|eot_id|><|start_header_id|>user<|end_header_id|>
Answer the user's question by making use of the following functions if needed.
If none of the function can be used, please say so.
Here is a list of functions in JSON format:
{
"type": "function",
"function": {
"name": "trending_songs",
"description": "Returns the trending songs on a Music site",
"parameters": {
"type": "object",
"properties": [
{
"n": {
"type": "object",
"description": "The number of songs to return"
}
},
{
"genre": {
"type": "object",
"description": "The genre of the songs to return"
}
}
],
"required": ["n"]
}
}
}
Return function calls in JSON format.<|eot_id|><|start_header_id|>user<|end_header_id|>
Use tools to get latest trending songs<|eot_id|><|start_header_id|>assistant<|end_header_id|>
```
##### Model Response Format
```
<|python_tag|>{
"type": "function",
"name": "trending_songs",
"parameters": {
"n": "10",
"genre": "all"
}
}<|eom_id|>
```
- JSON format for providing tools needs name, description and parameters
- Model responds with `<|python_tag|>` and `<|eom_id|>` as `Environment: ipython` was in the system prompt
- Instructions for tools added as a user message
- Only single tool calls are supported as of now (see the parsing sketch after this list)
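A minimal sketch of how an executor could strip the special tags and parse the JSON tool call above; the tag handling mirrors the example response and is an assumption, not a prescribed API.
```
import json

response = '<|python_tag|>{"type": "function", "name": "trending_songs", "parameters": {"n": "10", "genre": "all"}}<|eom_id|>'
body = response.removeprefix("<|python_tag|>").removesuffix("<|eom_id|>")
call = json.loads(body)
# Prints: trending_songs {'n': '10', 'genre': 'all'}
print(call["name"], call["parameters"])
```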
## Example of user-defined tool calling
## `<function>` based tool calling
Here is an example of how you could also write custom instructions for the model to do zero-shot tool calling.
In this example, we define a custom tool calling format using the `<function>` tag.
##### Input Prompt Format
```
<|begin_of_text|><|start_header_id|>system<|end_header_id|>
Environment: ipython
Cutting Knowledge Date: December 2023
Today Date: 21 September 2024
You are a helpful assistant.
<|eot_id|><|start_header_id|>user<|end_header_id|>
You have access to the following functions:
Use the function 'trending_songs' to 'Returns the trending songs on a Music site':
{"name": "trending_songs", "description": "Returns the trending songs on a Music site", "parameters": {"genre": {"description": "The genre of the songs to return", "param_type": "str", "required": false}, "n": {"description": "The number of songs to return", "param_type": "int", "required": true}}}
Think very carefully before calling functions.
If you choose to call a function ONLY reply in the following format with no prefix or suffix:
<function=example_function_name>{"example_name": "example_value"}</function>
Reminder:
- If looking for real time information use relevant functions before falling back to brave_search
- Function calls MUST follow the specified format, start with <function= and end with </function>
- Required parameters MUST be specified
- Only call one function at a time
- Put the entire function call reply on one line<|eot_id|><|start_header_id|>user<|end_header_id|>
Use tools to get latest trending songs<|eot_id|><|start_header_id|>assistant<|end_header_id|>
```
##### Model Response Format
```
<function=trending_songs>{"n": 10}</function><|eot_id|>
```
- In this case, the model does NOT respond with `<|python_tag|>` and ends with `<|eot_id|>` (a minimal parsing sketch follows below)
- Instructions for tools added as a user message
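A minimal sketch of how the `<function=...>` response could be parsed back into a name and arguments; the regular expression is an illustrative assumption.
```
import json
import re

FUNCTION_TAG = re.compile(r"<function=(?P<name>[^>]+)>(?P<args>\{.*?\})</function>")

response = '<function=trending_songs>{"n": 10}</function><|eot_id|>'
match = FUNCTION_TAG.search(response)
if match:
    # Prints: trending_songs {'n': 10}
    print(match.group("name"), json.loads(match.group("args")))
```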
Thank You!


@@ -14,7 +14,7 @@
 import textwrap
 from typing import List
-from llama_models.datatypes import (
+from llama_stack.models.llama.datatypes import (
     BuiltinTool,
     RawMessage,
     StopReason,

Some files were not shown because too many files have changed in this diff.