Merge branch 'meta-llama:main' into fix-ollama-rag

commit 2868d8f793
Author: Kai Wu, 2025-03-03 10:40:14 -08:00 (committed by GitHub)
GPG key ID: B5690EEEBB952194 (no known key found for this signature in database)
215 changed files with 137658 additions and 2561 deletions

View file

@@ -8,6 +8,8 @@ repos:
    rev: v5.0.0  # Latest stable version
    hooks:
    - id: check-merge-conflict
+   - id: trailing-whitespace
+     exclude: '\.py$'  # Exclude Python files as Ruff already handles them
    - id: check-added-large-files
      args: ['--maxkb=1000']
    - id: end-of-file-fixer
@@ -83,10 +85,8 @@ repos:
    - id: distro-codegen
      name: Distribution Template Codegen
      additional_dependencies:
-     - rich
-     - pydantic
      - uv==0.6.0
-     entry: uv run python -m llama_stack.scripts.distro_codegen
+     entry: uv run --extra codegen python -m llama_stack.scripts.distro_codegen
      language: python
      pass_filenames: false
      require_serial: true

View file

@@ -70,6 +70,19 @@ $ uv pip install -e .
$ source .venv/bin/activate
```
Note that you can create a dotenv file `.env` that includes necessary environment variables:
```
LLAMA_STACK_BASE_URL=http://localhost:8321
LLAMA_STACK_CLIENT_LOG=debug
LLAMA_STACK_PORT=8321
LLAMA_STACK_CONFIG=
```
And then use this dotenv file when running client SDK tests via the following:
```bash
$ uv run --env-file .env -- pytest -v tests/client-sdk/inference/test_text_inference.py
```
## Pre-commit Hooks

We use [pre-commit](https://pre-commit.com/) to run linting and formatting checks on your code. You can install the pre-commit hooks by running:
@@ -110,15 +123,15 @@ Some tips about common tasks you work on while contributing to Llama Stack:
### Using `llama stack build`
-Building a stack image (conda / docker) will use the production version of the `llama-stack`, `llama-models` and `llama-stack-client` packages. If you are developing with a llama-stack repository checked out and need your code to be reflected in the stack image, set `LLAMA_STACK_DIR` and `LLAMA_MODELS_DIR` to the appropriate checked out directories when running any of the `llama` CLI commands.
+Building a stack image (conda / docker) will use the production version of the `llama-stack` and `llama-stack-client` packages. If you are developing with a llama-stack repository checked out and need your code to be reflected in the stack image, set `LLAMA_STACK_DIR` and `LLAMA_STACK_CLIENT_DIR` to the appropriate checked out directories when running any of the `llama` CLI commands.

Example:
```bash
$ cd work/
$ git clone https://github.com/meta-llama/llama-stack.git
-$ git clone https://github.com/meta-llama/llama-models.git
+$ git clone https://github.com/meta-llama/llama-stack-client-python.git
$ cd llama-stack
-$ LLAMA_STACK_DIR=$(pwd) LLAMA_MODELS_DIR=../llama-models llama stack build --template <...>
+$ LLAMA_STACK_DIR=$(pwd) LLAMA_STACK_CLIENT_DIR=../llama-stack-client-python llama stack build --template <...>
```

View file

@@ -32,7 +32,7 @@ Llama Stack standardizes the core building blocks that simplify AI application d
By reducing friction and complexity, Llama Stack empowers developers to focus on what they do best: building transformative generative AI applications.

### API Providers
Here is a list of the various API providers and available distributions that can help developers get started easily with Llama Stack.

| **API Provider Builder** | **Environments** | **Agents** | **Inference** | **Memory** | **Safety** | **Telemetry** |
|:------------------------:|:----------------------:|:----------:|:-------------:|:----------:|:----------:|:-------------:|

View file

@@ -216,8 +216,8 @@
    "faiss-cpu",
    "fastapi",
    "fire",
-   "groq",
    "httpx",
+   "litellm",
    "matplotlib",
    "nltk",
    "numpy",
@@ -431,6 +431,7 @@
    "fire",
    "httpx",
    "matplotlib",
+   "mcp",
    "nltk",
    "numpy",
    "ollama",

File diff suppressed because it is too large

File diff suppressed because it is too large

File diff suppressed because one or more lines are too long

View file

@@ -45,65 +45,7 @@
    "id": "O9pGVlPIjpix",
    "outputId": "e1fbe723-ae31-4630-eb80-4c4f6476d56f"
   },
-  "outputs": [
-   ... (~55 removed lines of saved cell output: pip "Requirement already satisfied: ..." messages for llama-stack and its dependencies) ...
-  ],
+  "outputs": [],
   "source": [
    "# NBVAL_SKIP\n",
    "!pip install -U llama-stack"
@@ -120,198 +62,10 @@
    "id": "JQpLUSNjlGAM",
    "outputId": "2f7fec97-5511-4cae-d51e-6d262fbca19c"
   },
-  "outputs": [
-   ... (~185 removed lines of saved cell output: pip "Requirement already satisfied" messages for the distribution's provider dependencies, ending with "Build Successful!") ...
-  ],
+  "outputs": [],
   "source": [
    "# NBVAL_SKIP\n",
-   "!llama stack build --template together --image-type venv --image-name __system__"
+   "!UV_SYSTEM_PYTHON=1 llama stack build --template together --image-type venv"
   ]
  },
  {

View file

@@ -3,7 +3,7 @@ The RFC Specification (OpenAPI format) is generated from the set of API endpoint
Please install the following packages before running the script:
```
-pip install fire PyYAML llama-models
+pip install fire PyYAML
```
Then simply run `sh run_openapi_generator.sh`

View file

@@ -55,6 +55,7 @@ def main(output_dir: str):
            a set of endpoints and their corresponding interfaces that are tailored to
            best leverage Llama Models.""",
        ),
+       include_standard_error_responses=True,
    ),
)

View file

@@ -10,6 +10,7 @@ import typing
from dataclasses import make_dataclass
from typing import Any, Dict, Set, Union

+from llama_stack.apis.datatypes import Error
from llama_stack.strong_typing.core import JsonType
from llama_stack.strong_typing.docstring import Docstring, parse_type
from llama_stack.strong_typing.inspection import (
@@ -434,6 +435,75 @@ class Generator:
        )
        self.schema_builder = SchemaBuilder(schema_generator)
        self.responses = {}
# Create standard error responses
self._create_standard_error_responses()
def _create_standard_error_responses(self) -> None:
"""
Creates standard error responses that can be reused across operations.
These will be added to the components.responses section of the OpenAPI document.
"""
# Get the Error schema
error_schema = self.schema_builder.classdef_to_ref(Error)
# Create standard error responses
self.responses["BadRequest400"] = Response(
description="The request was invalid or malformed",
content={
"application/json": MediaType(
schema=error_schema,
example={
"status": 400,
"title": "Bad Request",
"detail": "The request was invalid or malformed",
}
)
}
)
self.responses["TooManyRequests429"] = Response(
description="The client has sent too many requests in a given amount of time",
content={
"application/json": MediaType(
schema=error_schema,
example={
"status": 429,
"title": "Too Many Requests",
"detail": "You have exceeded the rate limit. Please try again later.",
}
)
}
)
self.responses["InternalServerError500"] = Response(
description="The server encountered an unexpected error",
content={
"application/json": MediaType(
schema=error_schema,
example={
"status": 500,
"title": "Internal Server Error",
"detail": "An unexpected error occurred. Our team has been notified.",
}
)
}
)
# Add a default error response for any unhandled error cases
self.responses["DefaultError"] = Response(
description="An unexpected error occurred",
content={
"application/json": MediaType(
schema=error_schema,
example={
"status": 0,
"title": "Error",
"detail": "An unexpected error occurred",
}
)
}
)
    def _build_type_tag(self, ref: str, schema: Schema) -> Tag:
        # Don't include schema definition in the tag description because for one,
@@ -649,6 +719,18 @@ class Generator:
        responses.update(response_builder.build_response(response_options))

        assert len(responses.keys()) > 0, f"No responses found for {op.name}"
# Add standard error response references
if self.options.include_standard_error_responses:
if "400" not in responses:
responses["400"] = ResponseRef("BadRequest400")
if "429" not in responses:
responses["429"] = ResponseRef("TooManyRequests429")
if "500" not in responses:
responses["500"] = ResponseRef("InternalServerError500")
if "default" not in responses:
responses["default"] = ResponseRef("DefaultError")
        if op.event_type is not None:
            builder = ContentBuilder(self.schema_builder)
            callbacks = {

View file

@@ -35,6 +35,7 @@ class Options:
    :param error_wrapper: True if errors are encapsulated in an error object wrapper.
    :param property_description_fun: Custom transformation function to apply to class property documentation strings.
    :param captions: User-defined captions for sections such as "Operations" or "Types", and (if applicable) groups of extra types.
+   :param include_standard_error_responses: Whether to include standard error responses (400, 429, 500, and a default) in all operations.
    """

    server: Server
@@ -52,6 +53,7 @@ class Options:
    error_wrapper: bool = False
    property_description_fun: Optional[Callable[[type, str, str], str]] = None
    captions: Optional[Dict[str, str]] = None
+   include_standard_error_responses: bool = True

    default_captions: ClassVar[Dict[str, str]] = {
        "Operations": "Operations",

View file

@@ -28,6 +28,5 @@ if [ ${#missing_packages[@]} -ne 0 ]; then
fi

stack_dir=$(dirname $(dirname $THIS_DIR))
-models_dir=$(dirname $stack_dir)/llama-models
-PYTHONPATH=$PYTHONPATH:$stack_dir:$models_dir \
+PYTHONPATH=$PYTHONPATH:$stack_dir \
python -m docs.openapi_generator.generate $(dirname $THIS_DIR)/_static

View file

@@ -0,0 +1,91 @@
# Llama Stack Agent Framework
The Llama Stack agent framework is built on a modular architecture that allows for flexible and powerful AI applications. This document explains the key components and how they work together.
## Core Concepts
### 1. Agent Configuration
Agents are configured using the `AgentConfig` class, which includes:
- **Model**: The underlying LLM to power the agent
- **Instructions**: System prompt that defines the agent's behavior
- **Tools**: Capabilities the agent can use to interact with external systems
- **Safety Shields**: Guardrails to ensure responsible AI behavior
```python
from llama_stack_client.types.agent_create_params import AgentConfig
from llama_stack_client.lib.agents.agent import Agent
# Configure an agent
agent_config = AgentConfig(
model="meta-llama/Llama-3-70b-chat",
instructions="You are a helpful assistant that can use tools to answer questions.",
toolgroups=["builtin::code_interpreter", "builtin::rag/knowledge_search"],
)
# Create the agent
agent = Agent(llama_stack_client, agent_config)
```
### 2. Sessions
Agents maintain state through sessions, which represent a conversation thread:
```python
# Create a session
session_id = agent.create_session(session_name="My conversation")
```
### 3. Turns
Each interaction with an agent is called a "turn" and consists of:
- **Input Messages**: What the user sends to the agent
- **Steps**: The agent's internal processing (inference, tool execution, etc.)
- **Output Message**: The agent's response
```python
from llama_stack_client.lib.agents.event_logger import EventLogger
# Create a turn with streaming response
turn_response = agent.create_turn(
session_id=session_id,
messages=[{"role": "user", "content": "Tell me about Llama models"}],
)
for log in EventLogger().log(turn_response):
log.print()
```
### Non-Streaming
```python
from rich.pretty import pprint
# Non-streaming API
response = agent.create_turn(
session_id=session_id,
messages=[{"role": "user", "content": "Tell me about Llama models"}],
stream=False,
)
print("Inputs:")
pprint(response.input_messages)
print("Output:")
pprint(response.output_message.content)
print("Steps:")
pprint(response.steps)
```
### 4. Steps
Each turn consists of multiple steps that represent the agent's thought process:
- **Inference Steps**: The agent generating text responses
- **Tool Execution Steps**: The agent using tools to gather information
- **Shield Call Steps**: Safety checks being performed
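For example, with the non-streaming API shown above, you can iterate over `response.steps` to see what kind of processing happened at each stage (a minimal sketch; exact step fields may vary by client version):

```python
response = agent.create_turn(
    session_id=session_id,
    messages=[{"role": "user", "content": "Tell me about Llama models"}],
    stream=False,
)

for step in response.steps:
    # step_type is e.g. "inference", "tool_execution", or "shield_call"
    print(step.step_type)
```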
## Agent Execution Loop
Refer to the [Agent Execution Loop](agent_execution_loop) for more details on what happens within an agent turn.

View file

@@ -13,7 +13,7 @@ Each agent turn follows these key steps:
3. **Inference Loop**: The agent enters its main execution loop:
   - The LLM receives a user prompt (with previous tool outputs)
-  - The LLM generates a response, potentially with tool calls
+  - The LLM generates a response, potentially with [tool calls](tools)
   - If tool calls are present:
     - Tool inputs are safety-checked
     - Tools are executed (e.g., web search, code execution)
@@ -67,9 +67,17 @@ sequenceDiagram
Each step in this process can be monitored and controlled through configurations. Here's an example that demonstrates monitoring the agent's execution:

```python
+from llama_stack_client import LlamaStackClient
+from llama_stack_client.lib.agents.agent import Agent
from llama_stack_client.lib.agents.event_logger import EventLogger
+from llama_stack_client.types.agent_create_params import AgentConfig
+from rich.pretty import pprint
+
+# Replace host and port
+client = LlamaStackClient(base_url=f"http://{HOST}:{PORT}")

agent_config = AgentConfig(
+    # Check with `llama-stack-client models list`
    model="Llama3.2-3B-Instruct",
    instructions="You are a helpful assistant",
    # Enable both RAG and tool usage
@@ -80,7 +88,7 @@ agent_config = AgentConfig(
        },
        "builtin::code_interpreter",
    ],
-    # Configure safety
+    # Configure safety (optional)
    input_shields=["llama_guard"],
    output_shields=["llama_guard"],
    # Control the inference loop
@@ -97,7 +105,7 @@ session_id = agent.create_session("monitored_session")
# Stream the agent's execution steps
response = agent.create_turn(
    messages=[{"role": "user", "content": "Analyze this code and run it"}],
-    attachments=[
+    documents=[
        {
            "content": "https://raw.githubusercontent.com/example/code.py",
            "mime_type": "text/plain",
@@ -108,14 +116,21 @@ response = agent.create_turn(
# Monitor each step of execution
for log in EventLogger().log(response):
-    if log.event.step_type == "memory_retrieval":
-        print("Retrieved context:", log.event.retrieved_context)
-    elif log.event.step_type == "inference":
-        print("LLM output:", log.event.model_response)
-    elif log.event.step_type == "tool_execution":
-        print("Tool call:", log.event.tool_call)
-        print("Tool response:", log.event.tool_response)
-    elif log.event.step_type == "shield_call":
-        if log.event.violation:
-            print("Safety violation:", log.event.violation)
+    log.print()
+
+# Using non-streaming API, the response contains input, steps, and output.
+response = agent.create_turn(
+    messages=[{"role": "user", "content": "Analyze this code and run it"}],
+    documents=[
+        {
+            "content": "https://raw.githubusercontent.com/example/code.py",
+            "mime_type": "text/plain",
+        }
+    ],
+    session_id=session_id,
+)
+
+pprint(f"Input: {response.input_messages}")
+pprint(f"Output: {response.output_message.content}")
+pprint(f"Steps: {response.steps}")
```

View file

@@ -149,7 +149,6 @@ agent_config = {
        }
    ],
    "tool_choice": "auto",
-   "tool_prompt_format": "json",
    "input_shields": [],
    "output_shields": [],
    "enable_session_persistence": False,

View file

@@ -8,22 +8,24 @@ The best way to get started is to look at this notebook which walks through the
Here are some key topics that will help you build effective agents:

-- **[Agent Execution Loop](agent_execution_loop)**
-- **[RAG](rag)**
-- **[Safety](safety)**
-- **[Tools](tools)**
-- **[Telemetry](telemetry)**
-- **[Evals](evals)**
+- **[Agent](agent)**: Understand the components and design patterns of the Llama Stack agent framework.
+- **[Agent Execution Loop](agent_execution_loop)**: Understand how agents process information, make decisions, and execute actions in a continuous loop.
+- **[RAG (Retrieval-Augmented Generation)](rag)**: Learn how to enhance your agents with external knowledge through retrieval mechanisms.
+- **[Tools](tools)**: Extend your agents' capabilities by integrating with external tools and APIs.
+- **[Evals](evals)**: Evaluate your agents' effectiveness and identify areas for improvement.
+- **[Telemetry](telemetry)**: Monitor and analyze your agents' performance and behavior.
+- **[Safety](safety)**: Implement guardrails and safety measures to ensure responsible AI behavior.

```{toctree}
:hidden:
:maxdepth: 1

+agent
agent_execution_loop
rag
-safety
tools
telemetry
evals
advanced_agent_patterns
+safety
```

View file

@@ -1,8 +1,8 @@
-## Using "Memory" or Retrieval Augmented Generation (RAG)
+## Using Retrieval Augmented Generation (RAG)

-Memory enables your applications to reference and recall information from previous interactions or external documents.
+RAG enables your applications to reference and recall information from previous interactions or external documents.

-Llama Stack organizes the memory APIs into three layers:
+Llama Stack organizes the APIs that enable RAG into three layers:
- the lowermost APIs deal with raw storage and retrieval. These include Vector IO, KeyValue IO (coming soon) and Relational IO (also coming soon.) - the lowermost APIs deal with raw storage and retrieval. These include Vector IO, KeyValue IO (coming soon) and Relational IO (also coming soon.)
- next is the "Rag Tool", a first-class tool as part of the Tools API that allows you to ingest documents (from URLs, files, etc) with various chunking strategies and query them smartly. - next is the "Rag Tool", a first-class tool as part of the Tools API that allows you to ingest documents (from URLs, files, etc) with various chunking strategies and query them smartly.
- finally, it all comes together with the top-level "Agents" API that allows you to create agents that can use the tools to answer questions, perform tasks, and more. - finally, it all comes together with the top-level "Agents" API that allows you to create agents that can use the tools to answer questions, perform tasks, and more.
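As a sketch of how the lower two layers are used together (the vector db id, `faiss` provider id, embedding model, and document URL below are illustrative assumptions, not part of this diff), you might register a vector database and ingest documents through the RAG tool like this:

```python
from llama_stack_client import LlamaStackClient
from llama_stack_client.types import Document

client = LlamaStackClient(base_url="http://localhost:8321")

# Register a vector database with an available vector_io provider (assumed: faiss)
client.vector_dbs.register(
    vector_db_id="my_documents",
    embedding_model="all-MiniLM-L6-v2",
    embedding_dimension=384,
    provider_id="faiss",
)

# Ingest a document through the RAG tool; it is chunked, embedded, and stored
client.tool_runtime.rag_tool.insert(
    documents=[
        Document(
            document_id="doc-1",
            content="https://www.example.com/some-document.txt",
            mime_type="text/plain",
            metadata={},
        )
    ],
    vector_db_id="my_documents",
    chunk_size_in_tokens=512,
)
```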
@@ -86,7 +86,7 @@ from llama_stack_client.lib.agents.agent import Agent
# Configure agent with memory
agent_config = AgentConfig(
-    model="meta-llama/Llama-3.2-3B-Instruct",
+    model="meta-llama/Llama-3.3-70B-Instruct",
    instructions="You are a helpful assistant",
    enable_session_persistence=False,
    toolgroups=[
@@ -102,6 +102,19 @@ agent_config = AgentConfig(
agent = Agent(client, agent_config)
session_id = agent.create_session("rag_session")
# Ask questions about documents in the vector db, and the agent will query the db to answer the question.
response = agent.create_turn(
messages=[{"role": "user", "content": "How to optimize memory in PyTorch?"}],
session_id=session_id,
)
```
> **NOTE:** the `instructions` field in the `AgentConfig` can be used to guide the agent's behavior. It is important to experiment with different instructions to see what works best for your use case.
You can also pass documents along with the user's message and ask questions about them.
```python
# Initial document ingestion
response = agent.create_turn(
    messages=[

View file

@@ -83,15 +83,15 @@ result = client.tool_runtime.invoke_tool(
)
```

-#### Memory
+#### RAG

-The Memory tool enables retrieval of context from various types of memory banks (vector, key-value, keyword, and graph).
+The RAG tool enables retrieval of context from various types of memory banks (vector, key-value, keyword, and graph).

```python
# Register Memory tool group
client.toolgroups.register(
-    toolgroup_id="builtin::memory",
-    provider_id="memory",
+    toolgroup_id="builtin::rag",
+    provider_id="faiss",
    args={"max_chunks": 5, "max_tokens_in_context": 4096},
)
```
@@ -102,7 +102,7 @@ Features:
- Context retrieval with token limits

-> **Note:** By default, llama stack run.yaml defines toolgroups for web search, code interpreter and memory, that are provided by tavily-search, code-interpreter and memory providers.
+> **Note:** By default, llama stack run.yaml defines toolgroups for web search, code interpreter and rag, that are provided by tavily-search, code-interpreter and rag providers.
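To check which tool groups your running distribution actually exposes, a quick sketch (assuming a client pointed at a local server on port 8321):

```python
from llama_stack_client import LlamaStackClient

client = LlamaStackClient(base_url="http://localhost:8321")

# List the tool groups registered in the running distribution
for toolgroup in client.toolgroups.list():
    print(toolgroup.identifier, toolgroup.provider_id)
```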
## Model Context Protocol (MCP) Tools
@@ -125,50 +125,43 @@ MCP tools require:
- Tools are discovered dynamically from the endpoint
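Registering a remote MCP server as a tool group might look like the following sketch (the toolgroup id and SSE endpoint URI are illustrative assumptions):

```python
from llama_stack_client import LlamaStackClient

client = LlamaStackClient(base_url="http://localhost:8321")

# Register a tool group served by a remote MCP server over SSE
client.toolgroups.register(
    toolgroup_id="mcp::filesystem",
    provider_id="model-context-protocol",
    mcp_endpoint={"uri": "http://localhost:8000/sse"},
)
```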
-## Tools provided by the client
+## Adding Custom Tools

-These tools are registered along with the agent config and are specific to the agent for which they are registered. The main difference between these tools and the tools provided by the built-in providers is that the execution of these tools is handled by the client and the agent transfers the tool call to the client and waits for the result from the client.
+When you want to use tools other than the built-in tools, you can implement a python function and decorate it with `@client_tool`.

To define a custom tool, you need to use the `@client_tool` decorator.
```python
from llama_stack_client.lib.agents.client_tool import client_tool


# Example tool definition
@client_tool
def my_tool(input: int) -> int:
    """
    Runs my awesome tool.

    :param input: some int parameter
    """
    return input * 2
```
> **NOTE:** We employ python docstrings to describe the tool and the parameters. It is important to document the tool and the parameters so that the model can use the tool correctly. It is recommended to experiment with different docstrings to see how they affect the model's behavior.

Once defined, simply pass the tool to the agent config. `Agent` will take care of the rest (calling the model with the tool definition, executing the tool, and returning the result to the model for the next iteration).
```python
# Example agent config with client provided tools
-config = AgentConfig(
-    toolgroups=[
-        "builtin::websearch",
-    ],
-    client_tools=[ToolDef(name="client_tool", description="Client provided tool")],
-)
+client_tools = [
+    my_tool,
+]
+
+agent_config = AgentConfig(
+    ...,
+    client_tools=[client_tool.get_tool_definition() for client_tool in client_tools],
+)
+
+agent = Agent(client, agent_config, client_tools)
```

Refer to [llama-stack-apps](https://github.com/meta-llama/llama-stack-apps/blob/main/examples/agents/e2e_loop_with_client_tools.py) for an example of how to use client provided tools.
## Tool Structure
Each tool has the following components:
- `name`: Unique identifier for the tool
- `description`: Human-readable description of the tool's functionality
- `parameters`: List of parameters the tool accepts
- `name`: Parameter name
- `parameter_type`: Data type (string, number, etc.)
- `description`: Parameter description
- `required`: Whether the parameter is required (default: true)
- `default`: Default value if any
Example tool definition:
```python
{
    "name": "web_search",
    "description": "Search the web for information",
    "parameters": [
        {
            "name": "query",
            "parameter_type": "string",
            "description": "The query to search for",
            "required": True,
        }
    ],
}
```

## Tool Invocation
@@ -30,7 +30,7 @@ Build a Llama stack container

options:
  -h, --help            show this help message and exit
  --config CONFIG       Path to a config file to use for the build. You can find example configs in llama_stack/distributions/**/build.yaml.
                        If this argument is not provided, you will be prompted to enter information interactively
  --template TEMPLATE   Name of the example template config to use for build. You may use `llama stack build --list-templates` to check out the available templates
  --list-templates      Show the available templates for building a Llama Stack distribution

@@ -106,7 +106,7 @@ It would be best to start with a template and understand the structure of the co

llama stack build
> Enter a name for your Llama Stack (e.g. my-local-stack): my-stack
> Enter the image type you want your Llama Stack to be built as (container or conda or venv): conda

Llama Stack is composed of several APIs working together. Let's select
the provider types (implementations) you want to use for these APIs.

@@ -187,7 +187,7 @@ usage: llama stack run [-h] [--port PORT] [--image-name IMAGE_NAME] [--disable-i
                       [--tls-certfile TLS_CERTFILE] [--image-type {conda,container,venv}]
                       config

Start the server for a Llama Stack Distribution. You should have already built (or downloaded) and configured the distribution.

positional arguments:
  config  Path to config file to use for the run
@@ -27,16 +27,19 @@ The following environment variables can be configured:

The following models are available by default:

- `meta/llama3-8b-instruct (aliases: meta-llama/Llama-3-8B-Instruct)`
- `meta/llama3-70b-instruct (aliases: meta-llama/Llama-3-70B-Instruct)`
- `meta/llama-3.1-8b-instruct (aliases: meta-llama/Llama-3.1-8B-Instruct)`
- `meta/llama-3.1-70b-instruct (aliases: meta-llama/Llama-3.1-70B-Instruct)`
- `meta/llama-3.1-405b-instruct (aliases: meta-llama/Llama-3.1-405B-Instruct-FP8)`
- `meta/llama-3.2-1b-instruct (aliases: meta-llama/Llama-3.2-1B-Instruct)`
- `meta/llama-3.2-3b-instruct (aliases: meta-llama/Llama-3.2-3B-Instruct)`
- `meta/llama-3.2-11b-vision-instruct (aliases: meta-llama/Llama-3.2-11B-Vision-Instruct)`
- `meta/llama-3.2-90b-vision-instruct (aliases: meta-llama/Llama-3.2-90B-Vision-Instruct)`
- `nvidia/llama-3.2-nv-embedqa-1b-v2`
- `nvidia/nv-embedqa-e5-v5`
- `nvidia/nv-embedqa-mistral-7b-v2`
- `snowflake/arctic-embed-l`

### Prerequisite: API Keys
@@ -34,9 +34,9 @@ The following environment variables can be configured:

The following models are available by default:

- `meta.llama3-1-8b-instruct-v1:0 (aliases: meta-llama/Llama-3.1-8B-Instruct)`
- `meta.llama3-1-70b-instruct-v1:0 (aliases: meta-llama/Llama-3.1-70B-Instruct)`
- `meta.llama3-1-405b-instruct-v1:0 (aliases: meta-llama/Llama-3.1-405B-Instruct-FP8)`

### Prerequisite: API Keys
@@ -27,8 +27,8 @@ The following environment variables can be configured:

The following models are available by default:

- `llama3.1-8b (aliases: meta-llama/Llama-3.1-8B-Instruct)`
- `llama-3.3-70b (aliases: meta-llama/Llama-3.3-70B-Instruct)`

### Prerequisite: API Keys
@@ -37,17 +37,17 @@ The following environment variables can be configured:

The following models are available by default:

- `accounts/fireworks/models/llama-v3p1-8b-instruct (aliases: meta-llama/Llama-3.1-8B-Instruct)`
- `accounts/fireworks/models/llama-v3p1-70b-instruct (aliases: meta-llama/Llama-3.1-70B-Instruct)`
- `accounts/fireworks/models/llama-v3p1-405b-instruct (aliases: meta-llama/Llama-3.1-405B-Instruct-FP8)`
- `accounts/fireworks/models/llama-v3p2-1b-instruct (aliases: meta-llama/Llama-3.2-1B-Instruct)`
- `accounts/fireworks/models/llama-v3p2-3b-instruct (aliases: meta-llama/Llama-3.2-3B-Instruct)`
- `accounts/fireworks/models/llama-v3p2-11b-vision-instruct (aliases: meta-llama/Llama-3.2-11B-Vision-Instruct)`
- `accounts/fireworks/models/llama-v3p2-90b-vision-instruct (aliases: meta-llama/Llama-3.2-90B-Vision-Instruct)`
- `accounts/fireworks/models/llama-v3p3-70b-instruct (aliases: meta-llama/Llama-3.3-70B-Instruct)`
- `accounts/fireworks/models/llama-guard-3-8b (aliases: meta-llama/Llama-Guard-3-8B)`
- `accounts/fireworks/models/llama-guard-3-11b-vision (aliases: meta-llama/Llama-Guard-3-11B-Vision)`
- `nomic-ai/nomic-embed-text-v1.5`

### Prerequisite: API Keys
@@ -37,11 +37,11 @@ The following environment variables can be configured:

The following models are available by default:

- `groq/llama3-8b-8192 (aliases: meta-llama/Llama-3.1-8B-Instruct)`
- `groq/llama-3.1-8b-instant`
- `groq/llama3-70b-8192 (aliases: meta-llama/Llama-3-70B-Instruct)`
- `groq/llama-3.3-70b-versatile (aliases: meta-llama/Llama-3.3-70B-Instruct)`
- `groq/llama-3.2-3b-preview (aliases: meta-llama/Llama-3.2-3B-Instruct)`

### Prerequisite: API Keys
@@ -41,12 +41,31 @@ The following environment variables can be configured:

## Prerequisite: Downloading Models

Please use `llama model list --downloaded` to check that you have llama model checkpoints downloaded in `~/.llama` before proceeding. See the [installation guide](https://llama-stack.readthedocs.io/en/latest/references/llama_cli_reference/download_models.html) to download the models. Run `llama model list` to see the available models to download, and `llama model download` to download the checkpoints.

```
$ llama model list --downloaded
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━┓
┃ Model                                   ┃ Size     ┃ Modified Time       ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━┩
│ Llama3.2-1B-Instruct:int4-qlora-eo8 │ 1.53 GB │ 2025-02-26 11:22:28 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Llama3.2-1B │ 2.31 GB │ 2025-02-18 21:48:52 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Prompt-Guard-86M │ 0.02 GB │ 2025-02-26 11:29:28 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Llama3.2-3B-Instruct:int4-spinquant-eo8 │ 3.69 GB │ 2025-02-26 11:37:41 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Llama3.2-3B │ 5.99 GB │ 2025-02-18 21:51:26 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Llama3.1-8B │ 14.97 GB │ 2025-02-16 10:36:37 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Llama3.2-1B-Instruct:int4-spinquant-eo8 │ 1.51 GB │ 2025-02-26 11:35:02 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Llama-Guard-3-1B │ 2.80 GB │ 2025-02-26 11:20:46 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Llama-Guard-3-1B:int4 │ 0.43 GB │ 2025-02-26 11:33:33 │
└─────────────────────────────────────────┴──────────┴─────────────────────┘
```

## Running the Distribution
@@ -41,12 +41,31 @@ The following environment variables can be configured:

## Prerequisite: Downloading Models

Please use `llama model list --downloaded` to check that you have llama model checkpoints downloaded in `~/.llama` before proceeding. See the [installation guide](https://llama-stack.readthedocs.io/en/latest/references/llama_cli_reference/download_models.html) to download the models. Run `llama model list` to see the available models to download, and `llama model download` to download the checkpoints.

```
$ llama model list --downloaded
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━┓
┃ Model                                   ┃ Size     ┃ Modified Time       ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━┩
│ Llama3.2-1B-Instruct:int4-qlora-eo8 │ 1.53 GB │ 2025-02-26 11:22:28 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Llama3.2-1B │ 2.31 GB │ 2025-02-18 21:48:52 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Prompt-Guard-86M │ 0.02 GB │ 2025-02-26 11:29:28 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Llama3.2-3B-Instruct:int4-spinquant-eo8 │ 3.69 GB │ 2025-02-26 11:37:41 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Llama3.2-3B │ 5.99 GB │ 2025-02-18 21:51:26 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Llama3.1-8B │ 14.97 GB │ 2025-02-16 10:36:37 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Llama3.2-1B-Instruct:int4-spinquant-eo8 │ 1.51 GB │ 2025-02-26 11:35:02 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Llama-Guard-3-1B │ 2.80 GB │ 2025-02-26 11:20:46 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Llama-Guard-3-1B:int4 │ 0.43 GB │ 2025-02-26 11:33:33 │
└─────────────────────────────────────────┴──────────┴─────────────────────┘
```

## Running the Distribution
@@ -22,7 +22,7 @@ The `llamastack/distribution-ollama` distribution consists of the following prov

| safety | `inline::llama-guard` |
| scoring | `inline::basic`, `inline::llm-as-judge`, `inline::braintrust` |
| telemetry | `inline::meta-reference` |
| tool_runtime | `remote::brave-search`, `remote::tavily-search`, `inline::code-interpreter`, `inline::rag-runtime`, `remote::model-context-protocol` |
| vector_io | `inline::sqlite-vec`, `remote::chromadb`, `remote::pgvector` |
@@ -141,17 +141,21 @@ ollama run <model_name>

To make sure that the model is being served correctly, run `ollama ps` to get a list of models being served by ollama.

```
$ ollama ps
NAME                         ID              SIZE      PROCESSOR    UNTIL
llama3.2:3b-instruct-fp16    195a8c01d91e    8.6 GB    100% GPU     9 minutes from now
```
To verify that the model served by ollama is correctly connected to the Llama Stack server:

```bash
$ llama-stack-client models list

Available Models

┏━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━┳━━━━━━━━━━━━━┓
┃ model_type   ┃ identifier                           ┃ provider_resource_id         ┃ metadata  ┃ provider_id ┃
┡━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━╇━━━━━━━━━━━━━┩
│ llm          │ meta-llama/Llama-3.2-3B-Instruct     │ llama3.2:3b-instruct-fp16    │           │ ollama      │
└──────────────┴──────────────────────────────────────┴──────────────────────────────┴───────────┴─────────────┘

Total models: 1
```
@@ -34,15 +34,15 @@ The following environment variables can be configured:

The following models are available by default:

- `Meta-Llama-3.1-8B-Instruct (aliases: meta-llama/Llama-3.1-8B-Instruct)`
- `Meta-Llama-3.1-70B-Instruct (aliases: meta-llama/Llama-3.1-70B-Instruct)`
- `Meta-Llama-3.1-405B-Instruct (aliases: meta-llama/Llama-3.1-405B-Instruct-FP8)`
- `Meta-Llama-3.2-1B-Instruct (aliases: meta-llama/Llama-3.2-1B-Instruct)`
- `Meta-Llama-3.2-3B-Instruct (aliases: meta-llama/Llama-3.2-3B-Instruct)`
- `Meta-Llama-3.3-70B-Instruct (aliases: meta-llama/Llama-3.3-70B-Instruct)`
- `Llama-3.2-11B-Vision-Instruct (aliases: meta-llama/Llama-3.2-11B-Vision-Instruct)`
- `Llama-3.2-90B-Vision-Instruct (aliases: meta-llama/Llama-3.2-90B-Vision-Instruct)`
- `Meta-Llama-Guard-3-8B (aliases: meta-llama/Llama-Guard-3-8B)`

### Prerequisite: API Keys
@@ -37,17 +37,17 @@ The following environment variables can be configured:

The following models are available by default:

- `meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo (aliases: meta-llama/Llama-3.1-8B-Instruct)`
- `meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo (aliases: meta-llama/Llama-3.1-70B-Instruct)`
- `meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo (aliases: meta-llama/Llama-3.1-405B-Instruct-FP8)`
- `meta-llama/Llama-3.2-3B-Instruct-Turbo (aliases: meta-llama/Llama-3.2-3B-Instruct)`
- `meta-llama/Llama-3.2-11B-Vision-Instruct-Turbo (aliases: meta-llama/Llama-3.2-11B-Vision-Instruct)`
- `meta-llama/Llama-3.2-90B-Vision-Instruct-Turbo (aliases: meta-llama/Llama-3.2-90B-Vision-Instruct)`
- `meta-llama/Llama-3.3-70B-Instruct-Turbo (aliases: meta-llama/Llama-3.3-70B-Instruct)`
- `meta-llama/Meta-Llama-Guard-3-8B (aliases: meta-llama/Llama-Guard-3-8B)`
- `meta-llama/Llama-Guard-3-11B-Vision-Turbo (aliases: meta-llama/Llama-Guard-3-11B-Vision)`
- `togethercomputer/m2-bert-80M-8k-retrieval`
- `togethercomputer/m2-bert-80M-32k-retrieval`

### Prerequisite: API Keys
@@ -38,7 +38,7 @@ The API is **exactly identical** for both clients.

:::{dropdown} Starting up the Llama Stack server
The Llama Stack server can be configured flexibly so you can mix-and-match various providers for its individual API components -- beyond Inference, these include Vector IO, Agents, Telemetry, Evals, Post Training, etc.

To get started quickly, we provide various container images for the server component that work with different inference providers out of the box. For this guide, we will use `llamastack/distribution-ollama` as the container image. If you'd like to build your own image or customize the configurations, please check out [this guide](../references/index.md).

Let's set up some environment variables that we will use in the rest of the guide.

```bash
@@ -102,12 +102,18 @@ Let's use the `llama-stack-client` CLI to check the connectivity to the server.

$ llama-stack-client configure --endpoint http://localhost:$LLAMA_STACK_PORT
> Enter the API key (leave empty if no key is needed):
Done! You can now use the Llama Stack Client CLI with endpoint http://localhost:8321

$ llama-stack-client models list

Available Models

┏━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━┳━━━━━━━━━━━━━┓
┃ model_type   ┃ identifier                           ┃ provider_resource_id         ┃ metadata  ┃ provider_id ┃
┡━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━╇━━━━━━━━━━━━━┩
│ llm          │ meta-llama/Llama-3.2-3B-Instruct     │ llama3.2:3b-instruct-fp16    │           │ ollama      │
└──────────────┴──────────────────────────────────────┴──────────────────────────────┴───────────┴─────────────┘

Total models: 1
```

You can test basic Llama inference completion using the CLI too.
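The same check can also be done from Python (a minimal sketch; the model id is taken from the listing above and assumes the server from this guide):

```python
from llama_stack_client import LlamaStackClient

client = LlamaStackClient(base_url="http://localhost:8321")

# One-shot chat completion against the model registered above.
response = client.inference.chat_completion(
    model_id="meta-llama/Llama-3.2-3B-Instruct",
    messages=[{"role": "user", "content": "Write a haiku about llamas."}],
)
print(response.completion_message.content)
```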
@@ -255,7 +261,7 @@ rag_agent = Agent(client, agent_config)

session_id = rag_agent.create_session("test-session")

user_prompts = [
    "How to optimize memory usage in torchtune? use the knowledge_search tool to get information.",
]

# Run the agent loop by calling the `create_turn` method
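For context, the loop body typically looks like this (a hedged sketch, assuming the `EventLogger` helper from `llama_stack_client.lib.agents.event_logger`):

```python
for prompt in user_prompts:
    # create_turn streams back the agent's steps for this user message.
    response = rag_agent.create_turn(
        messages=[{"role": "user", "content": prompt}],
        session_id=session_id,
    )
    for log in EventLogger().log(response):
        log.print()
```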
@@ -36,7 +36,7 @@ Evaluates the outputs of the system.

Collects telemetry data from the system.

## Tool Runtime

Is associated with the ToolGroup resources.

## Vector IO
@@ -1,10 +1,10 @@
---
orphan: true
---

# Chroma

[Chroma](https://www.trychroma.com/) is an inline and remote vector
database provider for Llama Stack. It allows you to store and query vectors directly within a Chroma database.
That means you're not limited to storing vectors in memory or in a separate service.

## Features
@@ -3,7 +3,7 @@ orphan: true
---

# Faiss

[Faiss](https://github.com/facebookresearch/faiss) is an inline vector database provider for Llama Stack. It
allows you to store and query vectors directly in memory.
That means you'll get fast and efficient vector retrieval.

@@ -29,5 +29,5 @@ You can install Faiss using pip:

pip install faiss-cpu
```
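Once the stack is running with the faiss provider, registering a vector database against it is straightforward (a minimal sketch; the ids and embedding model are illustrative):

```python
from llama_stack_client import LlamaStackClient

client = LlamaStackClient(base_url="http://localhost:8321")

# Register a vector database backed by the inline faiss provider.
client.vector_dbs.register(
    vector_db_id="my-documents",         # illustrative id
    provider_id="faiss",
    embedding_model="all-MiniLM-L6-v2",  # assumes this embedding model is available
    embedding_dimension=384,
)
```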
## Documentation

See [Faiss' documentation](https://faiss.ai/) or the [Faiss Wiki](https://github.com/facebookresearch/faiss/wiki) for
more details about Faiss in general.
@@ -3,7 +3,7 @@ orphan: true
---

# Postgres PGVector

[PGVector](https://github.com/pgvector/pgvector) is a remote vector database provider for Llama Stack. It
allows you to store and query vectors directly within a PostgreSQL database.
That means you'll get fast and efficient vector retrieval.
@@ -3,7 +3,7 @@ orphan: true
---

# Qdrant

[Qdrant](https://qdrant.tech/documentation/) is a remote vector database provider for Llama Stack. It
allows you to store and query vectors directly within a Qdrant instance.
That means you'll get fast and efficient vector retrieval.
@@ -3,8 +3,8 @@ orphan: true
---

# SQLite-Vec

[SQLite-Vec](https://github.com/asg017/sqlite-vec) is an inline vector database provider for Llama Stack. It
allows you to store and query vectors directly within an SQLite database.
That means you're not limited to storing vectors in memory or in a separate service.

## Features
@@ -1,10 +1,10 @@
---
orphan: true
---

# Weaviate

[Weaviate](https://weaviate.io/) is a vector database provider for Llama Stack.
It allows you to store and query vectors directly within a Weaviate database.
That means you're not limited to storing vectors in memory or in a separate service.

## Features

@@ -27,7 +27,7 @@ To use Weaviate in your Llama Stack project, follow these steps:

## Installation

To install Weaviate see the [Weaviate quickstart documentation](https://weaviate.io/developers/weaviate/quickstart).

## Documentation

See [Weaviate's documentation](https://weaviate.io/developers/weaviate) for more details about Weaviate in general.
@@ -129,3 +129,35 @@ llama download --source huggingface --model-id Prompt-Guard-86M --ignore-pattern

**Important:** Set your environment variable `HF_TOKEN` or pass in `--hf-token` to the command to validate your access. You can find your token at [https://huggingface.co/settings/tokens](https://huggingface.co/settings/tokens).

> **Tip:** Default for `llama download` is to run with `--ignore-patterns *.safetensors` since we use the `.pth` files in the `original` folder. For Llama Guard and Prompt Guard, however, we need safetensors. Hence, please run with `--ignore-patterns original` so that safetensors are downloaded and `.pth` files are ignored.

## List the downloaded models

To list the downloaded models, use the following command:
```
llama model list --downloaded
```
You should see a table like this:
```
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━┓
┃ Model ┃ Size ┃ Modified Time ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━┩
│ Llama3.2-1B-Instruct:int4-qlora-eo8 │ 1.53 GB │ 2025-02-26 11:22:28 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Llama3.2-1B │ 2.31 GB │ 2025-02-18 21:48:52 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Prompt-Guard-86M │ 0.02 GB │ 2025-02-26 11:29:28 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Llama3.2-3B-Instruct:int4-spinquant-eo8 │ 3.69 GB │ 2025-02-26 11:37:41 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Llama3.2-3B │ 5.99 GB │ 2025-02-18 21:51:26 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Llama3.1-8B │ 14.97 GB │ 2025-02-16 10:36:37 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Llama3.2-1B-Instruct:int4-spinquant-eo8 │ 1.51 GB │ 2025-02-26 11:35:02 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Llama-Guard-3-1B │ 2.80 GB │ 2025-02-26 11:20:46 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Llama-Guard-3-1B:int4 │ 0.43 GB │ 2025-02-26 11:33:33 │
└─────────────────────────────────────────┴──────────┴─────────────────────┘
```
@@ -154,6 +154,38 @@ llama download --source huggingface --model-id Prompt-Guard-86M --ignore-pattern

> **Tip:** Default for `llama download` is to run with `--ignore-patterns *.safetensors` since we use the `.pth` files in the `original` folder. For Llama Guard and Prompt Guard, however, we need safetensors. Hence, please run with `--ignore-patterns original` so that safetensors are downloaded and `.pth` files are ignored.

## List the downloaded models

To list the downloaded models, use the following command:
```
llama model list --downloaded
```
You should see a table like this:
```
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━┓
┃ Model ┃ Size ┃ Modified Time ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━┩
│ Llama3.2-1B-Instruct:int4-qlora-eo8 │ 1.53 GB │ 2025-02-26 11:22:28 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Llama3.2-1B │ 2.31 GB │ 2025-02-18 21:48:52 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Prompt-Guard-86M │ 0.02 GB │ 2025-02-26 11:29:28 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Llama3.2-3B-Instruct:int4-spinquant-eo8 │ 3.69 GB │ 2025-02-26 11:37:41 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Llama3.2-3B │ 5.99 GB │ 2025-02-18 21:51:26 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Llama3.1-8B │ 14.97 GB │ 2025-02-16 10:36:37 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Llama3.2-1B-Instruct:int4-spinquant-eo8 │ 1.51 GB │ 2025-02-26 11:35:02 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Llama-Guard-3-1B │ 2.80 GB │ 2025-02-26 11:20:46 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Llama-Guard-3-1B:int4 │ 0.43 GB │ 2025-02-26 11:33:33 │
└─────────────────────────────────────────┴──────────┴─────────────────────┘
```
## Understand the models

The `llama model` command helps you explore the models interface.
@@ -58,11 +58,15 @@ llama-stack-client providers list

llama-stack-client models list
```
```
Available Models

┏━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━┳━━━━━━━━━━━━━┓
┃ model_type   ┃ identifier                           ┃ provider_resource_id         ┃ metadata  ┃ provider_id ┃
┡━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━╇━━━━━━━━━━━━━┩
│ llm          │ meta-llama/Llama-3.2-3B-Instruct     │ llama3.2:3b-instruct-fp16    │           │ ollama      │
└──────────────┴──────────────────────────────────────┴──────────────────────────────┴───────────┴─────────────┘

Total models: 1
```
### `llama-stack-client models get`
@@ -73,7 +73,7 @@ If you're looking for more specific topics, we have a [Zero to Hero Guide](#next

Open a new terminal and install `llama-stack`:

```bash
conda activate ollama
pip install -U llama-stack
```

---
@@ -5,6 +5,9 @@
# the root directory of this source tree.

from enum import Enum
from typing import Optional

from pydantic import BaseModel

from llama_stack.schema_utils import json_schema_type

@@ -33,3 +36,20 @@ class Api(Enum):

    # built-in API
    inspect = "inspect"


@json_schema_type
class Error(BaseModel):
    """
    Error response from the API. Roughly follows RFC 7807.

    :param status: HTTP status code
    :param title: Error title, a short summary of the error which is invariant for an error type
    :param detail: Error detail, a longer human-readable description of the error
    :param instance: (Optional) A URL which can be used to retrieve more information about the specific occurrence of the error
    """

    status: int
    title: str
    detail: str
    instance: Optional[str] = None
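To illustrate how this model is meant to be used (a hedged sketch; the import path mirrors the module above and the lookup failure is hypothetical):

```python
from llama_stack.apis.datatypes import Error  # assumed path of the module shown above

# Build a structured, RFC 7807-style error payload for a failed lookup.
err = Error(
    status=404,
    title="Model not found",
    detail="No model with identifier 'my-model' is registered with this stack.",
)

# Pydantic provides the JSON body to return to the client.
print(err.model_dump_json())
```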
@@ -9,6 +9,7 @@ import argparse

from .download import Download
from .model import ModelParser
from .stack import StackParser
from .stack.utils import print_subcommand_description
from .verify_download import VerifyDownload

@@ -20,6 +21,7 @@ class LlamaCLIParser:

            prog="llama",
            description="Welcome to the Llama CLI",
            add_help=True,
            formatter_class=argparse.RawTextHelpFormatter,
        )

        # Default command is to print help

@@ -33,6 +35,8 @@ class LlamaCLIParser:

        Download.create(subparsers)
        VerifyDownload.create(subparsers)

        print_subcommand_description(self.parser, subparsers)

    def parse_args(self) -> argparse.Namespace:
        return self.parser.parse_args()
@@ -12,6 +12,7 @@ from llama_stack.cli.model.list import ModelList

from llama_stack.cli.model.prompt_format import ModelPromptFormat
from llama_stack.cli.model.remove import ModelRemove
from llama_stack.cli.model.verify_download import ModelVerifyDownload
from llama_stack.cli.stack.utils import print_subcommand_description
from llama_stack.cli.subcommand import Subcommand

@@ -24,6 +25,7 @@ class ModelParser(Subcommand):

            "model",
            prog="llama model",
            description="Work with llama models",
            formatter_class=argparse.RawTextHelpFormatter,
        )

        self.parser.set_defaults(func=lambda args: self.parser.print_help())

@@ -37,3 +39,5 @@ class ModelParser(Subcommand):

        ModelDescribe.create(subparsers)
        ModelVerifyDownload.create(subparsers)
        ModelRemove.create(subparsers)

        print_subcommand_description(self.parser, subparsers)
@@ -7,10 +7,14 @@

import argparse
import textwrap
from io import StringIO
from pathlib import Path

from llama_stack.cli.subcommand import Subcommand
from llama_stack.cli.table import print_table
from llama_stack.models.llama.datatypes import CoreModelId, ModelFamily, is_multimodal, model_family

ROOT_DIR = Path(__file__).parent.parent


class ModelPromptFormat(Subcommand):
    """Llama model CLI for describing a model's prompt format (message formats)"""

@@ -48,7 +52,26 @@ class ModelPromptFormat(Subcommand):

        supported_model_ids = [
            m for m in CoreModelId if model_family(m) in {ModelFamily.llama3_1, ModelFamily.llama3_2}
        ]

        model_list = [m.value for m in supported_model_ids]
        model_str = "\n".join(model_list)

        if args.list:
            headers = ["Model(s)"]
            rows = []
            for m in model_list:
                rows.append(
                    [
                        m,
                    ]
                )
            print_table(
                rows,
                headers,
                separate_rows=True,
            )
            return

        try:
            model_id = CoreModelId(args.model_name)
        except ValueError:

@@ -57,9 +80,9 @@

        if model_id not in supported_model_ids:
            self.parser.error(f"{model_id} is not a valid Model. Choose one from --\n {model_str}")

        llama_3_1_file = ROOT_DIR / "models" / "llama" / "llama3_1" / "prompt_format.md"
        llama_3_2_text_file = ROOT_DIR / "models" / "llama" / "llama3_2" / "text_prompt_format.md"
        llama_3_2_vision_file = ROOT_DIR / "models" / "llama" / "llama3_2" / "vision_prompt_format.md"

        if model_family(model_id) == ModelFamily.llama3_1:
            with importlib.resources.as_file(llama_3_1_file) as f:
                content = f.open("r").read()
@@ -38,7 +38,7 @@ from llama_stack.distribution.distribution import get_provider_registry

from llama_stack.distribution.resolver import InvalidProviderError
from llama_stack.distribution.utils.config_dirs import DISTRIBS_BASE_DIR
from llama_stack.distribution.utils.dynamic import instantiate_class_type
from llama_stack.distribution.utils.exec import formulate_run_args, run_with_pty
from llama_stack.distribution.utils.image_types import ImageType
from llama_stack.providers.datatypes import Api

@@ -65,8 +65,6 @@ def run_stack_build_command(args: argparse.Namespace) -> None:

    if args.image_type == "venv":
        current_venv = os.environ.get("VIRTUAL_ENV")
        image_name = args.image_name or current_venv
    elif args.image_type == "conda":
        current_conda_env = os.environ.get("CONDA_DEFAULT_ENV")
        image_name = args.image_name or current_conda_env

@@ -143,7 +141,7 @@ def run_stack_build_command(args: argparse.Namespace) -> None:

            completer=WordCompleter(available_providers),
            complete_while_typing=True,
            validator=Validator.from_callable(
                lambda x: x in available_providers,  # noqa: B023 - see https://github.com/astral-sh/ruff/issues/7847
                error_message="Invalid provider, use <TAB> to see options",
            ),
        )

@@ -291,6 +289,8 @@ def _run_stack_build_command_from_build_config(

        if not image_name:
            raise ValueError("Please specify an image name when building a conda image")
    elif build_config.image_type == ImageType.venv.value:
        if not image_name and os.environ.get("UV_SYSTEM_PYTHON"):
            image_name = "__system__"
        if not image_name:
            raise ValueError("Please specify an image name when building a venv image")
@@ -26,7 +26,7 @@ class StackBuild(Subcommand):

            "--config",
            type=str,
            default=None,
            help="Path to a config file to use for the build. You can find example configs in llama_stack/distributions/**/build.yaml. If this argument is not provided, you will be prompted to enter information interactively",
        )

        self.parser.add_argument(
@@ -7,6 +7,7 @@

import argparse
from importlib.metadata import version

from llama_stack.cli.stack.utils import print_subcommand_description
from llama_stack.cli.subcommand import Subcommand

from .build import StackBuild

@@ -22,6 +23,7 @@ class StackParser(Subcommand):

            "stack",
            prog="llama stack",
            description="Operations for the Llama Stack / Distributions",
            formatter_class=argparse.RawTextHelpFormatter,
        )

        self.parser.add_argument(

@@ -39,3 +41,5 @@ class StackParser(Subcommand):

        StackListApis.create(subparsers)
        StackListProviders.create(subparsers)
        StackRun.create(subparsers)

        print_subcommand_description(self.parser, subparsers)
@@ -0,0 +1,14 @@

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.


def print_subcommand_description(parser, subparsers):
    """Print descriptions of subcommands."""
    description_text = ""
    for name, subcommand in subparsers.choices.items():
        description = subcommand.description
        description_text += f" {name:<21} {description}\n"
    parser.epilog = description_text
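For context, wiring this helper into a parser looks roughly like the CLI classes above (a minimal sketch; the subcommand names are illustrative):

```python
import argparse

# RawTextHelpFormatter preserves the line breaks in the generated epilog.
parser = argparse.ArgumentParser(prog="llama", formatter_class=argparse.RawTextHelpFormatter)
subparsers = parser.add_subparsers(title="subcommands")

subparsers.add_parser("model", description="Work with llama models")
subparsers.add_parser("stack", description="Operations for the Llama Stack / Distributions")

print_subcommand_description(parser, subparsers)
parser.print_help()  # the epilog now lists each subcommand with its description
```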
@@ -112,7 +112,7 @@ def test_parse_and_maybe_upgrade_config_old_format(old_config):

    inference_providers = result.providers["inference"]
    assert len(inference_providers) == 2
    assert {x.provider_id for x in inference_providers} == {
        "remote::ollama-00",
        "meta-reference-01",
    }
@@ -15,7 +15,6 @@ from termcolor import cprint

from llama_stack.distribution.datatypes import BuildConfig, Provider
from llama_stack.distribution.distribution import get_provider_registry
from llama_stack.distribution.utils.exec import run_command, run_with_pty
from llama_stack.distribution.utils.image_types import ImageType
from llama_stack.providers.datatypes import Api

@@ -103,8 +102,6 @@ def build_image(

            template_or_config,
            image_name,
            container_base,
            " ".join(normal_deps),
        ]
    elif build_config.image_type == ImageType.conda.value:
@@ -6,8 +6,8 @@

# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

LLAMA_STACK_DIR=${LLAMA_STACK_DIR:-}
LLAMA_STACK_CLIENT_DIR=${LLAMA_STACK_CLIENT_DIR:-}
TEST_PYPI_VERSION=${TEST_PYPI_VERSION:-}

# This timeout (in seconds) is necessary when installing PyTorch via uv since it's likely to time out
# Reference: https://github.com/astral-sh/uv/pull/1694

@@ -16,8 +16,8 @@ UV_HTTP_TIMEOUT=${UV_HTTP_TIMEOUT:-500}

if [ -n "$LLAMA_STACK_DIR" ]; then
  echo "Using llama-stack-dir=$LLAMA_STACK_DIR"
fi
if [ -n "$LLAMA_STACK_CLIENT_DIR" ]; then
  echo "Using llama-stack-client-dir=$LLAMA_STACK_CLIENT_DIR"
fi

if [ "$#" -lt 3 ]; then

@@ -52,7 +52,7 @@ ensure_conda_env_python310() {

  local python_version="3.10"

  # Check if conda command is available
  if ! is_command_available conda; then
    printf "${RED}Error: conda command not found. Is Conda installed and in your PATH?${NC}" >&2
    exit 1
  fi

@@ -87,8 +87,6 @@ ensure_conda_env_python310() {

    # these packages are damaged in test-pypi, so install them first
    uv pip install fastapi libcst
    uv pip install --extra-index-url https://test.pypi.org/simple/ \
      llama-stack==$TEST_PYPI_VERSION \
      $pip_dependencies
    if [ -n "$special_pip_deps" ]; then
@@ -111,22 +109,21 @@ ensure_conda_env_python310() {

  else
    PYPI_VERSION="${PYPI_VERSION:-}"
    if [ -n "$PYPI_VERSION" ]; then
      SPEC_VERSION="llama-stack==${PYPI_VERSION}"
    else
      SPEC_VERSION="llama-stack"
    fi
    uv pip install --no-cache-dir $SPEC_VERSION
  fi

  if [ -n "$LLAMA_STACK_CLIENT_DIR" ]; then
    if [ ! -d "$LLAMA_STACK_CLIENT_DIR" ]; then
      printf "${RED}Warning: LLAMA_STACK_CLIENT_DIR is set but directory does not exist: $LLAMA_STACK_CLIENT_DIR${NC}\n" >&2
      exit 1
    fi
    printf "Installing from LLAMA_STACK_CLIENT_DIR: $LLAMA_STACK_CLIENT_DIR\n"
    uv pip install --no-cache-dir -e "$LLAMA_STACK_CLIENT_DIR"
  fi

  # Install pip dependencies
@@ -1,4 +1,4 @@

#!/usr/bin/env bash

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

@@ -6,7 +6,6 @@

# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

LLAMA_STACK_DIR=${LLAMA_STACK_DIR:-}
LLAMA_STACK_CLIENT_DIR=${LLAMA_STACK_CLIENT_DIR:-}

@@ -20,26 +19,27 @@ UV_HTTP_TIMEOUT=${UV_HTTP_TIMEOUT:-500}

# mounting is not supported by docker buildx, so we use COPY instead
USE_COPY_NOT_MOUNT=${USE_COPY_NOT_MOUNT:-}

if [ "$#" -lt 4 ]; then
  # This only works for templates
  echo "Usage: $0 <template_or_config> <image_name> <container_base> <pip_dependencies> [<special_pip_deps>]" >&2
  exit 1
fi

set -euo pipefail

template_or_config="$1"
shift
image_name="$1"
shift
container_base="$1"
shift
pip_dependencies="$1"
shift
special_pip_deps="${1:-}"

# Define color codes
RED='\033[0;31m'
NC='\033[0m' # No Color

CONTAINER_BINARY=${CONTAINER_BINARY:-docker}

@@ -47,8 +47,10 @@ CONTAINER_OPTS=${CONTAINER_OPTS:-}

TEMP_DIR=$(mktemp -d)
SCRIPT_DIR=$(dirname "$(readlink -f "$0")")
source "$SCRIPT_DIR/common.sh"
add_to_container() {
  local input
  output_file="$TEMP_DIR/Containerfile"
  if [ -t 0 ]; then
    printf '%s\n' "$1" >>"$output_file"

@@ -58,15 +60,21 @@ add_to_container() {

  fi
}

# Check if container command is available
if ! is_command_available $CONTAINER_BINARY; then
  printf "${RED}Error: ${CONTAINER_BINARY} command not found. Is ${CONTAINER_BINARY} installed and in your PATH?${NC}" >&2
  exit 1
fi

# Update and install UBI9 components if UBI9 base image is used
if [[ $container_base == *"registry.access.redhat.com/ubi9"* ]]; then
  add_to_container << EOF
FROM $container_base
WORKDIR /app

RUN dnf -y update && dnf install -y iputils net-tools wget \
    vim-minimal python3.11 python3.11-pip python3.11-wheel \
    python3.11-setuptools && ln -s /bin/pip3.11 /bin/pip && ln -s /bin/python3.11 /bin/python && dnf clean all

ENV UV_SYSTEM_PYTHON=1
RUN pip install uv
@@ -107,7 +115,6 @@ EOF

fi

stack_mount="/app/llama-stack-source"
client_mount="/app/llama-stack-client-source"

install_local_package() {

@@ -131,10 +138,6 @@ EOF

}

if [ -n "$LLAMA_STACK_CLIENT_DIR" ]; then
  install_local_package "$LLAMA_STACK_CLIENT_DIR" "$client_mount" "LLAMA_STACK_CLIENT_DIR"
fi
@@ -150,12 +153,12 @@ EOF

  add_to_container << EOF
RUN uv pip install --no-cache --extra-index-url https://test.pypi.org/simple/ \
  --index-strategy unsafe-best-match \
  llama-stack==$TEST_PYPI_VERSION
EOF
else
  if [ -n "$PYPI_VERSION" ]; then
    SPEC_VERSION="llama-stack==${PYPI_VERSION}"
  else
    SPEC_VERSION="llama-stack"
  fi

@@ -165,6 +168,11 @@ EOF

  fi
fi
# remove uv after installation
add_to_container << EOF
RUN pip uninstall -y uv
EOF
# if template_or_config ends with .yaml, it is not a template and we should not use the --template flag
if [[ "$template_or_config" != *.yaml ]]; then
  add_to_container << EOF
@@ -185,26 +193,28 @@ RUN mkdir -p /.llama /.cache

RUN chmod -R g+rw /app /.llama /.cache
EOF

printf "Containerfile created successfully in %s/Containerfile\n\n" "$TEMP_DIR"
cat "$TEMP_DIR"/Containerfile
printf "\n"

# Start building the CLI arguments
CLI_ARGS=()

# Read CONTAINER_OPTS and put it in an array
read -ra CLI_ARGS <<< "$CONTAINER_OPTS"

if [ "$USE_COPY_NOT_MOUNT" != "true" ]; then
  if [ -n "$LLAMA_STACK_DIR" ]; then
    CLI_ARGS+=("-v" "$(readlink -f "$LLAMA_STACK_DIR"):$stack_mount")
  fi
  if [ -n "$LLAMA_STACK_CLIENT_DIR" ]; then
    CLI_ARGS+=("-v" "$(readlink -f "$LLAMA_STACK_CLIENT_DIR"):$client_mount")
  fi
fi
if command -v selinuxenabled &>/dev/null && selinuxenabled; then if is_command_available selinuxenabled && selinuxenabled; then
# Disable SELinux labels -- we don't want to relabel the llama-stack source dir # Disable SELinux labels -- we don't want to relabel the llama-stack source dir
CONTAINER_OPTS="$CONTAINER_OPTS --security-opt label=disable" CLI_ARGS+=("--security-opt" "label=disable")
fi fi
 # Set version tag based on PyPI version
@@ -212,7 +222,7 @@ if [ -n "$PYPI_VERSION" ]; then
   version_tag="$PYPI_VERSION"
 elif [ -n "$TEST_PYPI_VERSION" ]; then
   version_tag="test-$TEST_PYPI_VERSION"
-elif [[ -n "$LLAMA_STACK_DIR" || -n "$LLAMA_MODELS_DIR" ]]; then
+elif [[ -n "$LLAMA_STACK_DIR" || -n "$LLAMA_STACK_CLIENT_DIR" ]]; then
   version_tag="dev"
 else
   URL="https://pypi.org/pypi/llama-stack/json"
@@ -225,11 +235,11 @@ image_tag="$image_name:$version_tag"
 # Detect platform architecture
 ARCH=$(uname -m)
 if [ -n "$BUILD_PLATFORM" ]; then
-  PLATFORM="--platform $BUILD_PLATFORM"
+  CLI_ARGS+=("--platform $BUILD_PLATFORM")
 elif [ "$ARCH" = "arm64" ] || [ "$ARCH" = "aarch64" ]; then
-  PLATFORM="--platform linux/arm64"
+  CLI_ARGS+=("--platform" "linux/arm64")
 elif [ "$ARCH" = "x86_64" ]; then
-  PLATFORM="--platform linux/amd64"
+  CLI_ARGS+=("--platform" "linux/amd64")
 else
   echo "Unsupported architecture: $ARCH"
   exit 1
@@ -238,8 +248,13 @@ fi
 echo "PWD: $(pwd)"
 echo "Containerfile: $TEMP_DIR/Containerfile"
 set -x
-$CONTAINER_BINARY build $CONTAINER_OPTS $PLATFORM -t $image_tag \
-  -f "$TEMP_DIR/Containerfile" "." $mounts --progress=plain
+
+$CONTAINER_BINARY build \
+  "${CLI_ARGS[@]}" \
+  -t "$image_tag" \
+  -f "$TEMP_DIR/Containerfile" \
+  "." \
+  --progress=plain

 # clean up tmp/configs
 set +x

View file

@@ -9,8 +9,8 @@
 # TODO: combine this with build_conda_env.sh since it is almost identical
 # the only difference is that we don't do any conda-specific setup

-LLAMA_MODELS_DIR=${LLAMA_MODELS_DIR:-}
 LLAMA_STACK_DIR=${LLAMA_STACK_DIR:-}
+LLAMA_STACK_CLIENT_DIR=${LLAMA_STACK_CLIENT_DIR:-}
 TEST_PYPI_VERSION=${TEST_PYPI_VERSION:-}
 # This timeout (in seconds) is necessary when installing PyTorch via uv since it's likely to time out
 # Reference: https://github.com/astral-sh/uv/pull/1694
@@ -21,8 +21,8 @@ VIRTUAL_ENV=${VIRTUAL_ENV:-}
 if [ -n "$LLAMA_STACK_DIR" ]; then
   echo "Using llama-stack-dir=$LLAMA_STACK_DIR"
 fi
-if [ -n "$LLAMA_MODELS_DIR" ]; then
-  echo "Using llama-models-dir=$LLAMA_MODELS_DIR"
+if [ -n "$LLAMA_STACK_CLIENT_DIR" ]; then
+  echo "Using llama-stack-client-dir=$LLAMA_STACK_CLIENT_DIR"
 fi

 if [ "$#" -lt 2 ]; then
@@ -95,7 +95,7 @@ run() {
     # we are building a command line so word splitting is expected
     uv pip install --extra-index-url https://test.pypi.org/simple/ \
       --index-strategy unsafe-best-match \
-      llama-models=="$TEST_PYPI_VERSION" llama-stack=="$TEST_PYPI_VERSION" \
+      llama-stack=="$TEST_PYPI_VERSION" \
       $pip_dependencies
     if [ -n "$special_pip_deps" ]; then
       IFS='#' read -ra parts <<<"$special_pip_deps"
@@ -120,15 +120,14 @@ run() {
       uv pip install --no-cache-dir llama-stack
     fi

-    if [ -n "$LLAMA_MODELS_DIR" ]; then
-      if [ ! -d "$LLAMA_MODELS_DIR" ]; then
-        printf "${RED}Warning: LLAMA_MODELS_DIR is set but directory does not exist: %s${NC}\n" "$LLAMA_MODELS_DIR" >&2
+    if [ -n "$LLAMA_STACK_CLIENT_DIR" ]; then
+      if [ ! -d "$LLAMA_STACK_CLIENT_DIR" ]; then
+        printf "${RED}Warning: LLAMA_STACK_CLIENT_DIR is set but directory does not exist: %s${NC}\n" "$LLAMA_STACK_CLIENT_DIR" >&2
         exit 1
       fi
-      printf "Installing from LLAMA_MODELS_DIR: %s\n" "$LLAMA_MODELS_DIR"
-      uv pip uninstall llama-models
-      uv pip install --no-cache-dir -e "$LLAMA_MODELS_DIR"
+      printf "Installing from LLAMA_STACK_CLIENT_DIR: %s\n" "$LLAMA_STACK_CLIENT_DIR"
+      uv pip install --no-cache-dir -e "$LLAMA_STACK_CLIENT_DIR"
     fi

     # Install pip dependencies

View file

@@ -1,47 +0,0 @@
#!/bin/bash
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
CONTAINER_BINARY=${CONTAINER_BINARY:-docker}
CONTAINER_OPTS=${CONTAINER_OPTS:-}
LLAMA_STACK_DIR=${LLAMA_STACK_DIR:-}
set -euo pipefail
error_handler() {
echo "Error occurred in script at line: ${1}" >&2
exit 1
}
trap 'error_handler ${LINENO}' ERR
if [ $# -lt 2 ]; then
echo "Usage: $0 <container name> <build file path>"
exit 1
fi
container_image="$1"
host_build_dir="$2"
container_build_dir="/app/builds"
if command -v selinuxenabled &> /dev/null && selinuxenabled; then
# Disable SELinux labels
CONTAINER_OPTS="$CONTAINER_OPTS --security-opt label=disable"
fi
mounts=""
if [ -n "$LLAMA_STACK_DIR" ]; then
mounts="$mounts -v $(readlink -f $LLAMA_STACK_DIR):/app/llama-stack-source"
fi
set -x
$CONTAINER_BINARY run $CONTAINER_OPTS -it \
--entrypoint "/usr/local/bin/llama" \
-v $host_build_dir:$container_build_dir \
$mounts \
$container_image \
stack configure ./llamastack-build.yaml --output-dir $container_build_dir

View file

@@ -13,7 +13,7 @@ from llama_stack.providers.datatypes import Api, ProviderSpec

 def stack_apis() -> List[Api]:
-    return [v for v in Api]
+    return list(Api)


 class AutoRoutedApiInfo(BaseModel):
@@ -55,7 +55,7 @@ def builtin_automatically_routed_apis() -> List[AutoRoutedApiInfo]:

 def providable_apis() -> List[Api]:
-    routing_table_apis = set(x.routing_table_api for x in builtin_automatically_routed_apis())
+    routing_table_apis = {x.routing_table_api for x in builtin_automatically_routed_apis()}
     return [api for api in Api if api not in routing_table_apis and api != Api.inspect]

View file

@@ -104,7 +104,7 @@ def convert_to_pydantic(annotation: Any, value: Any) -> Any:
         logger.warning(
             f"Warning: direct client failed to convert parameter {value} into {annotation}: {e}",
         )
-        return value
+        raise ValueError(f"Failed to convert parameter {value} into {annotation}: {e}") from e


 class LlamaStackAsLibraryClient(LlamaStackClient):
@@ -324,6 +324,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
             await end_trace()

         json_content = json.dumps(convert_pydantic_to_json_value(result))
+
         mock_response = httpx.Response(
             status_code=httpx.codes.OK,
             content=json_content.encode("utf-8"),
@@ -335,7 +336,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
                 url=options.url,
                 params=options.params,
                 headers=options.headers or {},
-                json=options.json_data,
+                json=convert_pydantic_to_json_value(body),
             ),
         )
         response = APIResponse(
@@ -384,7 +385,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
                 url=options.url,
                 params=options.params,
                 headers=options.headers or {},
-                json=options.json_data,
+                json=convert_pydantic_to_json_value(body),
             ),
         )

View file

@@ -5,9 +5,9 @@
 # the root directory of this source tree.
 import importlib
 import inspect
-import logging
 from typing import Any, Dict, List, Set

+from llama_stack import logcat
 from llama_stack.apis.agents import Agents
 from llama_stack.apis.benchmarks import Benchmarks
 from llama_stack.apis.datasetio import DatasetIO
@@ -50,8 +50,6 @@ from llama_stack.providers.datatypes import (
     VectorDBsProtocolPrivate,
 )

-log = logging.getLogger(__name__)
-

 class InvalidProviderError(Exception):
     pass
@@ -115,8 +113,8 @@
     - flatmaps, sorts and resolves the providers in dependency order
     - for each API, produces either a (local, passthrough or router) implementation
     """
-    routing_table_apis = set(x.routing_table_api for x in builtin_automatically_routed_apis())
-    router_apis = set(x.router_api for x in builtin_automatically_routed_apis())
+    routing_table_apis = {x.routing_table_api for x in builtin_automatically_routed_apis()}
+    router_apis = {x.router_api for x in builtin_automatically_routed_apis()}

     providers_with_specs = {}

@@ -127,16 +125,21 @@
         specs = {}
         for provider in providers:
+            if not provider.provider_id or provider.provider_id == "__disabled__":
+                logcat.warning("core", f"Provider `{provider.provider_type}` for API `{api}` is disabled")
+                continue
+
             if provider.provider_type not in provider_registry[api]:
                 raise ValueError(f"Provider `{provider.provider_type}` is not available for API `{api}`")

             p = provider_registry[api][provider.provider_type]
             if p.deprecation_error:
-                log.error(p.deprecation_error, "red", attrs=["bold"])
+                logcat.error("core", p.deprecation_error)
                 raise InvalidProviderError(p.deprecation_error)
             elif p.deprecation_warning:
-                log.warning(
+                logcat.warning(
+                    "core",
                     f"Provider `{provider.provider_type}` for API `{api}` is deprecated and will be removed in a future release: {p.deprecation_warning}",
                 )

             p.deps__ = [a.value for a in p.api_dependencies] + [a.value for a in p.optional_api_dependencies]
@@ -210,10 +213,10 @@
             )
         )

-    log.info(f"Resolved {len(sorted_providers)} providers")
+    logcat.debug("core", f"Resolved {len(sorted_providers)} providers")
     for api_str, provider in sorted_providers:
-        log.info(f" {api_str} => {provider.provider_id}")
-    log.info("")
+        logcat.debug("core", f" {api_str} => {provider.provider_id}")
+    logcat.debug("core", "")

     impls = {}
     inner_impls_by_provider_id = {f"inner-{x.value}": {} for x in router_apis}
@@ -350,7 +353,7 @@ def check_protocol_compliance(obj: Any, protocol: Any) -> None:
             obj_params = set(obj_sig.parameters)
             obj_params.discard("self")
             if not (proto_params <= obj_params):
-                log.error(f"Method {name} incompatible proto: {proto_params} vs. obj: {obj_params}")
+                logcat.error("core", f"Method {name} incompatible proto: {proto_params} vs. obj: {obj_params}")
                 missing_methods.append((name, "signature_mismatch"))
             else:
                 # Check if the method is actually implemented in the class
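The `proto_params <= obj_params` comparison above works because `inspect.signature` parameter names are collected into sets; a minimal standalone sketch of the same subset check (the `Proto`/`Impl` classes here are hypothetical, not llama-stack types):

```python
import inspect


class Proto:
    def run(self, model_id: str, stream: bool = False): ...


class Impl:
    # An implementation may accept extra parameters, but must accept at least
    # everything the protocol declares.
    def run(self, model_id: str, stream: bool = False, extra: int = 0):
        return model_id


proto_params = set(inspect.signature(Proto.run).parameters) - {"self"}
obj_params = set(inspect.signature(Impl.run).parameters) - {"self"}

assert proto_params <= obj_params  # subset check: Impl is signature-compatible
```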

View file

@@ -4,8 +4,10 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

+import copy
 from typing import Any, AsyncGenerator, Dict, List, Optional

+from llama_stack import logcat
 from llama_stack.apis.common.content_types import (
     URL,
     InterleavedContent,
@@ -62,12 +64,15 @@ class VectorIORouter(VectorIO):
         self,
         routing_table: RoutingTable,
     ) -> None:
+        logcat.debug("core", "Initializing VectorIORouter")
         self.routing_table = routing_table

     async def initialize(self) -> None:
+        logcat.debug("core", "VectorIORouter.initialize")
         pass

     async def shutdown(self) -> None:
+        logcat.debug("core", "VectorIORouter.shutdown")
         pass

     async def register_vector_db(
@@ -78,6 +83,7 @@ class VectorIORouter(VectorIO):
         provider_id: Optional[str] = None,
         provider_vector_db_id: Optional[str] = None,
     ) -> None:
+        logcat.debug("core", f"VectorIORouter.register_vector_db: {vector_db_id}, {embedding_model}")
         await self.routing_table.register_vector_db(
             vector_db_id,
             embedding_model,
@@ -92,6 +98,10 @@ class VectorIORouter(VectorIO):
         chunks: List[Chunk],
         ttl_seconds: Optional[int] = None,
     ) -> None:
+        logcat.debug(
+            "core",
+            f"VectorIORouter.insert_chunks: {vector_db_id}, {len(chunks)} chunks, ttl_seconds={ttl_seconds}, chunk_ids={[chunk.id for chunk in chunks[:3]]}{' and more...' if len(chunks) > 3 else ''}",
+        )
         return await self.routing_table.get_provider_impl(vector_db_id).insert_chunks(vector_db_id, chunks, ttl_seconds)

     async def query_chunks(
@@ -100,6 +110,7 @@ class VectorIORouter(VectorIO):
         query: InterleavedContent,
         params: Optional[Dict[str, Any]] = None,
     ) -> QueryChunksResponse:
+        logcat.debug("core", f"VectorIORouter.query_chunks: {vector_db_id}")
         return await self.routing_table.get_provider_impl(vector_db_id).query_chunks(vector_db_id, query, params)

@@ -110,12 +121,15 @@ class InferenceRouter(Inference):
         self,
         routing_table: RoutingTable,
     ) -> None:
+        logcat.debug("core", "Initializing InferenceRouter")
         self.routing_table = routing_table

     async def initialize(self) -> None:
+        logcat.debug("core", "InferenceRouter.initialize")
         pass

     async def shutdown(self) -> None:
+        logcat.debug("core", "InferenceRouter.shutdown")
         pass

     async def register_model(
@@ -126,6 +140,10 @@ class InferenceRouter(Inference):
         metadata: Optional[Dict[str, Any]] = None,
         model_type: Optional[ModelType] = None,
     ) -> None:
+        logcat.debug(
+            "core",
+            f"InferenceRouter.register_model: {model_id=} {provider_model_id=} {provider_id=} {metadata=} {model_type=}",
+        )
         await self.routing_table.register_model(model_id, provider_model_id, provider_id, metadata, model_type)

     async def chat_completion(
@@ -141,6 +159,10 @@ class InferenceRouter(Inference):
         logprobs: Optional[LogProbConfig] = None,
         tool_config: Optional[ToolConfig] = None,
     ) -> AsyncGenerator:
+        logcat.debug(
+            "core",
+            f"InferenceRouter.chat_completion: {model_id=}, {stream=}, {messages=}, {tools=}, {tool_config=}, {response_format=}",
+        )
         model = await self.routing_table.get_model(model_id)
         if model is None:
             raise ValueError(f"Model '{model_id}' not found")
@@ -159,6 +181,7 @@ class InferenceRouter(Inference):
                 params["tool_prompt_format"] = tool_prompt_format
             tool_config = ToolConfig(**params)

+        tool_config = copy.copy(tool_config)
         tool_config.tool_prompt_format = tool_config.tool_prompt_format or get_default_tool_prompt_format(model_id)

         tools = tools or []
@@ -201,6 +224,10 @@ class InferenceRouter(Inference):
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> AsyncGenerator:
+        logcat.debug(
+            "core",
+            f"InferenceRouter.completion: {model_id=}, {stream=}, {content=}, {sampling_params=}, {response_format=}",
+        )
         model = await self.routing_table.get_model(model_id)
         if model is None:
             raise ValueError(f"Model '{model_id}' not found")
@@ -228,6 +255,7 @@ class InferenceRouter(Inference):
         output_dimension: Optional[int] = None,
         task_type: Optional[EmbeddingTaskType] = None,
     ) -> EmbeddingsResponse:
+        logcat.debug("core", f"InferenceRouter.embeddings: {model_id}")
         model = await self.routing_table.get_model(model_id)
         if model is None:
             raise ValueError(f"Model '{model_id}' not found")
@@ -247,12 +275,15 @@ class SafetyRouter(Safety):
         self,
         routing_table: RoutingTable,
     ) -> None:
+        logcat.debug("core", "Initializing SafetyRouter")
         self.routing_table = routing_table

     async def initialize(self) -> None:
+        logcat.debug("core", "SafetyRouter.initialize")
         pass

     async def shutdown(self) -> None:
+        logcat.debug("core", "SafetyRouter.shutdown")
         pass

     async def register_shield(
@@ -262,6 +293,7 @@ class SafetyRouter(Safety):
         provider_id: Optional[str] = None,
         params: Optional[Dict[str, Any]] = None,
     ) -> Shield:
+        logcat.debug("core", f"SafetyRouter.register_shield: {shield_id}")
         return await self.routing_table.register_shield(shield_id, provider_shield_id, provider_id, params)

     async def run_shield(
@@ -270,6 +302,7 @@ class SafetyRouter(Safety):
         messages: List[Message],
         params: Dict[str, Any] = None,
     ) -> RunShieldResponse:
+        logcat.debug("core", f"SafetyRouter.run_shield: {shield_id}")
         return await self.routing_table.get_provider_impl(shield_id).run_shield(
             shield_id=shield_id,
             messages=messages,
@@ -282,12 +315,15 @@ class DatasetIORouter(DatasetIO):
         self,
         routing_table: RoutingTable,
     ) -> None:
+        logcat.debug("core", "Initializing DatasetIORouter")
         self.routing_table = routing_table

     async def initialize(self) -> None:
+        logcat.debug("core", "DatasetIORouter.initialize")
         pass

     async def shutdown(self) -> None:
+        logcat.debug("core", "DatasetIORouter.shutdown")
         pass

     async def get_rows_paginated(
@@ -297,6 +333,7 @@ class DatasetIORouter(DatasetIO):
         page_token: Optional[str] = None,
         filter_condition: Optional[str] = None,
     ) -> PaginatedRowsResult:
+        logcat.debug("core", f"DatasetIORouter.get_rows_paginated: {dataset_id}, rows_in_page={rows_in_page}")
         return await self.routing_table.get_provider_impl(dataset_id).get_rows_paginated(
             dataset_id=dataset_id,
             rows_in_page=rows_in_page,
@@ -305,6 +342,7 @@ class DatasetIORouter(DatasetIO):
         )

     async def append_rows(self, dataset_id: str, rows: List[Dict[str, Any]]) -> None:
+        logcat.debug("core", f"DatasetIORouter.append_rows: {dataset_id}, {len(rows)} rows")
         return await self.routing_table.get_provider_impl(dataset_id).append_rows(
             dataset_id=dataset_id,
             rows=rows,
@@ -316,12 +354,15 @@ class ScoringRouter(Scoring):
         self,
         routing_table: RoutingTable,
     ) -> None:
+        logcat.debug("core", "Initializing ScoringRouter")
         self.routing_table = routing_table

     async def initialize(self) -> None:
+        logcat.debug("core", "ScoringRouter.initialize")
         pass

     async def shutdown(self) -> None:
+        logcat.debug("core", "ScoringRouter.shutdown")
         pass

     async def score_batch(
@@ -330,6 +371,7 @@ class ScoringRouter(Scoring):
         scoring_functions: Dict[str, Optional[ScoringFnParams]] = None,
         save_results_dataset: bool = False,
     ) -> ScoreBatchResponse:
+        logcat.debug("core", f"ScoringRouter.score_batch: {dataset_id}")
         res = {}
         for fn_identifier in scoring_functions.keys():
             score_response = await self.routing_table.get_provider_impl(fn_identifier).score_batch(
@@ -350,6 +392,7 @@ class ScoringRouter(Scoring):
         input_rows: List[Dict[str, Any]],
         scoring_functions: Dict[str, Optional[ScoringFnParams]] = None,
     ) -> ScoreResponse:
+        logcat.debug("core", f"ScoringRouter.score: {len(input_rows)} rows, {len(scoring_functions)} functions")
         res = {}
         # look up and map each scoring function to its provider impl
         for fn_identifier in scoring_functions.keys():
@@ -367,12 +410,15 @@ class EvalRouter(Eval):
         self,
         routing_table: RoutingTable,
     ) -> None:
+        logcat.debug("core", "Initializing EvalRouter")
         self.routing_table = routing_table

     async def initialize(self) -> None:
+        logcat.debug("core", "EvalRouter.initialize")
         pass

     async def shutdown(self) -> None:
+        logcat.debug("core", "EvalRouter.shutdown")
         pass

     async def run_eval(
@@ -380,6 +426,7 @@ class EvalRouter(Eval):
         benchmark_id: str,
         task_config: BenchmarkConfig,
     ) -> Job:
+        logcat.debug("core", f"EvalRouter.run_eval: {benchmark_id}")
         return await self.routing_table.get_provider_impl(benchmark_id).run_eval(
             benchmark_id=benchmark_id,
             task_config=task_config,
@@ -392,6 +439,7 @@ class EvalRouter(Eval):
         scoring_functions: List[str],
         task_config: BenchmarkConfig,
     ) -> EvaluateResponse:
+        logcat.debug("core", f"EvalRouter.evaluate_rows: {benchmark_id}, {len(input_rows)} rows")
         return await self.routing_table.get_provider_impl(benchmark_id).evaluate_rows(
             benchmark_id=benchmark_id,
             input_rows=input_rows,
@@ -404,6 +452,7 @@ class EvalRouter(Eval):
         benchmark_id: str,
         job_id: str,
     ) -> Optional[JobStatus]:
+        logcat.debug("core", f"EvalRouter.job_status: {benchmark_id}, {job_id}")
         return await self.routing_table.get_provider_impl(benchmark_id).job_status(benchmark_id, job_id)

     async def job_cancel(
@@ -411,6 +460,7 @@ class EvalRouter(Eval):
         benchmark_id: str,
         job_id: str,
     ) -> None:
+        logcat.debug("core", f"EvalRouter.job_cancel: {benchmark_id}, {job_id}")
         await self.routing_table.get_provider_impl(benchmark_id).job_cancel(
             benchmark_id,
             job_id,
@@ -421,6 +471,7 @@ class EvalRouter(Eval):
         benchmark_id: str,
         job_id: str,
     ) -> EvaluateResponse:
+        logcat.debug("core", f"EvalRouter.job_result: {benchmark_id}, {job_id}")
         return await self.routing_table.get_provider_impl(benchmark_id).job_result(
             benchmark_id,
             job_id,
@@ -433,6 +484,7 @@ class ToolRuntimeRouter(ToolRuntime):
         self,
         routing_table: RoutingTable,
     ) -> None:
+        logcat.debug("core", "Initializing ToolRuntimeRouter.RagToolImpl")
         self.routing_table = routing_table

     async def query(
@@ -441,6 +493,7 @@ class ToolRuntimeRouter(ToolRuntime):
         vector_db_ids: List[str],
         query_config: Optional[RAGQueryConfig] = None,
     ) -> RAGQueryResult:
+        logcat.debug("core", f"ToolRuntimeRouter.RagToolImpl.query: {vector_db_ids}")
         return await self.routing_table.get_provider_impl("knowledge_search").query(
             content, vector_db_ids, query_config
         )
@@ -451,6 +504,10 @@ class ToolRuntimeRouter(ToolRuntime):
         vector_db_id: str,
         chunk_size_in_tokens: int = 512,
     ) -> None:
+        logcat.debug(
+            "core",
+            f"ToolRuntimeRouter.RagToolImpl.insert: {vector_db_id}, {len(documents)} documents, chunk_size={chunk_size_in_tokens}",
+        )
         return await self.routing_table.get_provider_impl("insert_into_memory").insert(
             documents, vector_db_id, chunk_size_in_tokens
         )
@@ -459,6 +516,7 @@ class ToolRuntimeRouter(ToolRuntime):
         self,
         routing_table: RoutingTable,
     ) -> None:
+        logcat.debug("core", "Initializing ToolRuntimeRouter")
         self.routing_table = routing_table

         # HACK ALERT this should be in sync with "get_all_api_endpoints()"
@@ -467,12 +525,15 @@ class ToolRuntimeRouter(ToolRuntime):
             setattr(self, f"rag_tool.{method}", getattr(self.rag_tool, method))

     async def initialize(self) -> None:
+        logcat.debug("core", "ToolRuntimeRouter.initialize")
         pass

     async def shutdown(self) -> None:
+        logcat.debug("core", "ToolRuntimeRouter.shutdown")
         pass

     async def invoke_tool(self, tool_name: str, kwargs: Dict[str, Any]) -> Any:
+        logcat.debug("core", f"ToolRuntimeRouter.invoke_tool: {tool_name}")
         return await self.routing_table.get_provider_impl(tool_name).invoke_tool(
             tool_name=tool_name,
             kwargs=kwargs,
@@ -481,4 +542,5 @@ class ToolRuntimeRouter(ToolRuntime):
     async def list_runtime_tools(
         self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
     ) -> List[ToolDef]:
+        logcat.debug("core", f"ToolRuntimeRouter.list_runtime_tools: {tool_group_id}")
         return await self.routing_table.get_provider_impl(tool_group_id).list_tools(tool_group_id, mcp_endpoint)

View file

@@ -318,14 +318,7 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
             )
         model = await self.get_object_by_identifier("model", embedding_model)
         if model is None:
-            if embedding_model == "all-MiniLM-L6-v2":
-                raise ValueError(
-                    "Embeddings are now served via Inference providers. "
-                    "Please upgrade your run.yaml to include inline::sentence-transformer as an additional inference provider. "
-                    "See https://github.com/meta-llama/llama-stack/blob/main/llama_stack/templates/together/run.yaml for an example."
-                )
-            else:
-                raise ValueError(f"Model {embedding_model} not found")
+            raise ValueError(f"Model {embedding_model} not found")
         if model.model_type != ModelType.embedding:
             raise ValueError(f"Model {embedding_model} is not an embedding model")
         if "embedding_dimension" not in model.metadata:

View file

@@ -26,9 +26,9 @@ from fastapi import Path as FastapiPath
 from fastapi.exceptions import RequestValidationError
 from fastapi.responses import JSONResponse, StreamingResponse
 from pydantic import BaseModel, ValidationError
-from termcolor import cprint
 from typing_extensions import Annotated

+from llama_stack import logcat
 from llama_stack.distribution.datatypes import StackRunConfig
 from llama_stack.distribution.distribution import builtin_automatically_routed_apis
 from llama_stack.distribution.request_headers import set_request_provider_data
@@ -55,7 +55,7 @@ from .endpoints import get_all_api_endpoints
 REPO_ROOT = Path(__file__).parent.parent.parent.parent

 logging.basicConfig(level=logging.INFO, format="%(levelname)s %(asctime)s %(name)s:%(lineno)d: %(message)s")
-logger = logging.getLogger(__name__)
+logcat.init()


 def warn_with_traceback(message, category, filename, lineno, file=None, line=None):
@@ -142,23 +142,23 @@ def handle_signal(app, signum, _) -> None:
         not block the current execution.
     """
     signame = signal.Signals(signum).name
-    logger.info(f"Received signal {signame} ({signum}). Exiting gracefully...")
+    logcat.info("server", f"Received signal {signame} ({signum}). Exiting gracefully...")

     async def shutdown():
         try:
             # Gracefully shut down implementations
             for impl in app.__llama_stack_impls__.values():
                 impl_name = impl.__class__.__name__
-                logger.info("Shutting down %s", impl_name)
+                logcat.info("server", f"Shutting down {impl_name}")
                 try:
                     if hasattr(impl, "shutdown"):
                         await asyncio.wait_for(impl.shutdown(), timeout=5)
                     else:
-                        logger.warning("No shutdown method for %s", impl_name)
+                        logcat.warning("server", f"No shutdown method for {impl_name}")
                 except asyncio.TimeoutError:
-                    logger.exception("Shutdown timeout for %s ", impl_name, exc_info=True)
+                    logcat.exception("server", f"Shutdown timeout for {impl_name}")
                 except Exception as e:
-                    logger.exception("Failed to shutdown %s: %s", impl_name, {e})
+                    logcat.exception("server", f"Failed to shutdown {impl_name}: {e}")

             # Gather all running tasks
             loop = asyncio.get_running_loop()
@@ -172,7 +172,7 @@ def handle_signal(app, signum, _) -> None:
             try:
                 await asyncio.wait_for(asyncio.gather(*tasks, return_exceptions=True), timeout=10)
             except asyncio.TimeoutError:
-                logger.exception("Timeout while waiting for tasks to finish")
+                logcat.exception("server", "Timeout while waiting for tasks to finish")
         except asyncio.CancelledError:
             pass
         finally:
@@ -184,9 +184,9 @@ def handle_signal(app, signum, _) -> None:

 @asynccontextmanager
 async def lifespan(app: FastAPI):
-    logger.info("Starting up")
+    logcat.info("server", "Starting up")
     yield
-    logger.info("Shutting down")
+    logcat.info("server", "Shutting down")
     for impl in app.__llama_stack_impls__.values():
         await impl.shutdown()

@@ -209,10 +209,10 @@ async def sse_generator(event_gen):
             yield create_sse_event(item)
             await asyncio.sleep(0.01)
     except asyncio.CancelledError:
-        print("Generator cancelled")
+        logcat.info("server", "Generator cancelled")
         await event_gen.aclose()
     except Exception as e:
-        traceback.print_exception(e)
+        logcat.exception("server", "Error in sse_generator")
         yield create_sse_event(
             {
                 "error": {
@@ -234,7 +234,7 @@ def create_dynamic_typed_route(func: Any, method: str, route: str):
             value = func(**kwargs)
             return await maybe_await(value)
         except Exception as e:
-            traceback.print_exception(e)
+            logcat.exception("server", f"Error in {func.__name__}")
             raise translate_exception(e) from e

     sig = inspect.signature(func)
@@ -313,6 +313,8 @@ class ClientVersionMiddleware:

 def main():
+    logcat.init()
+
     """Start the LlamaStack server."""
     parser = argparse.ArgumentParser(description="Start the LlamaStack server.")
     parser.add_argument(
@@ -352,10 +354,10 @@ def main():
         for env_pair in args.env:
             try:
                 key, value = validate_env_pair(env_pair)
-                logger.info(f"Setting CLI environment variable {key} => {value}")
+                logcat.info("server", f"Setting CLI environment variable {key} => {value}")
                 os.environ[key] = value
             except ValueError as e:
-                logger.error(f"Error: {str(e)}")
+                logcat.error("server", f"Error: {str(e)}")
                 sys.exit(1)

     if args.yaml_config:
@@ -363,12 +365,12 @@
         config_file = Path(args.yaml_config)
         if not config_file.exists():
             raise ValueError(f"Config file {config_file} does not exist")
-        logger.info(f"Using config file: {config_file}")
+        logcat.info("server", f"Using config file: {config_file}")
     elif args.template:
         config_file = Path(REPO_ROOT) / "llama_stack" / "templates" / args.template / "run.yaml"
         if not config_file.exists():
             raise ValueError(f"Template {args.template} does not exist")
-        logger.info(f"Using template {args.template} config file: {config_file}")
+        logcat.info("server", f"Using template {args.template} config file: {config_file}")
     else:
         raise ValueError("Either --yaml-config or --template must be provided")

@@ -376,9 +378,10 @@ def main():
         config = replace_env_vars(yaml.safe_load(fp))
         config = StackRunConfig(**config)

-    logger.info("Run configuration:")
+    logcat.info("server", "Run configuration:")
     safe_config = redact_sensitive_fields(config.model_dump())
-    logger.info(yaml.dump(safe_config, indent=2))
+    for log_line in yaml.dump(safe_config, indent=2).split("\n"):
+        logcat.info("server", log_line)

     app = FastAPI(lifespan=lifespan)
     app.add_middleware(TracingMiddleware)
@@ -388,7 +391,7 @@ def main():
     try:
         impls = asyncio.run(construct_stack(config))
     except InvalidProviderError as e:
-        logger.error(f"Error: {str(e)}")
+        logcat.error("server", f"Error: {str(e)}")
         sys.exit(1)

     if Api.telemetry in impls:
@@ -433,11 +436,8 @@ def main():
             )
         )

-        logger.info(f"Serving API {api_str}")
-        for endpoint in endpoints:
-            cprint(f" {endpoint.method.upper()} {endpoint.route}", "white")
-        print("")
+        logcat.debug("server", f"Serving API {api_str}")

     app.exception_handler(RequestValidationError)(global_exception_handler)
     app.exception_handler(Exception)(global_exception_handler)
     signal.signal(signal.SIGINT, functools.partial(handle_signal, app))
@@ -463,10 +463,10 @@ def main():
             "ssl_keyfile": keyfile,
             "ssl_certfile": certfile,
         }
-        logger.info(f"HTTPS enabled with certificates:\n  Key: {keyfile}\n  Cert: {certfile}")
+        logcat.info("server", f"HTTPS enabled with certificates:\n  Key: {keyfile}\n  Cert: {certfile}")

     listen_host = ["::", "0.0.0.0"] if not args.disable_ipv6 else "0.0.0.0"
-    logger.info(f"Listening on {listen_host}:{port}")
+    logcat.info("server", f"Listening on {listen_host}:{port}")

     uvicorn_config = {
         "app": app,
View file

@@ -5,7 +5,6 @@
 # the root directory of this source tree.

 import importlib.resources
-import logging
 import os
 import re
 from typing import Any, Dict, Optional
@@ -13,6 +12,7 @@ from typing import Any, Dict, Optional
 import yaml
 from termcolor import colored

+from llama_stack import logcat
 from llama_stack.apis.agents import Agents
 from llama_stack.apis.batch_inference import BatchInference
 from llama_stack.apis.benchmarks import Benchmarks
@@ -39,8 +39,6 @@ from llama_stack.distribution.resolver import ProviderRegistry, resolve_impls
 from llama_stack.distribution.store.registry import create_dist_registry
 from llama_stack.providers.datatypes import Api

-log = logging.getLogger(__name__)
-

 class LlamaStack(
     VectorDBs,
@@ -101,12 +99,11 @@ async def register_resources(run_config: StackRunConfig, impls: Dict[Api, Any]):
         objects_to_process = response.data if hasattr(response, "data") else response

         for obj in objects_to_process:
-            log.info(
+            logcat.debug(
+                "core",
                 f"{rsrc.capitalize()}: {colored(obj.identifier, 'white', attrs=['bold'])} served by {colored(obj.provider_id, 'white', attrs=['bold'])}",
             )

-    log.info("")
-

 class EnvVarError(Exception):
     def __init__(self, var_name: str, path: str = ""):
@@ -155,18 +152,34 @@ def replace_env_vars(config: Any, path: str = "") -> Any:
         return result

     elif isinstance(config, str):
-        pattern = r"\${env\.([A-Z0-9_]+)(?::([^}]*))?}"
+        # Updated pattern to support both default values (:) and conditional values (+)
+        pattern = r"\${env\.([A-Z0-9_]+)(?:([:\+])([^}]*))?}"

         def get_env_var(match):
             env_var = match.group(1)
-            default_val = match.group(2)
+            operator = match.group(2)  # ':' for default, '+' for conditional
+            value_expr = match.group(3)

-            value = os.environ.get(env_var)
-            if not value:
-                if default_val is None:
-                    raise EnvVarError(env_var, path)
+            env_value = os.environ.get(env_var)
+
+            if operator == ":":  # Default value syntax: ${env.FOO:default}
+                if not env_value:
+                    if value_expr is None:
+                        raise EnvVarError(env_var, path)
+                    else:
+                        value = value_expr
                 else:
-                    value = default_val
+                    value = env_value
+            elif operator == "+":  # Conditional value syntax: ${env.FOO+value_if_set}
+                if env_value:
+                    value = value_expr
+                else:
+                    # If env var is not set, return empty string for the conditional case
+                    value = ""
+            else:  # No operator case: ${env.FOO}
+                if not env_value:
+                    raise EnvVarError(env_var, path)
+                value = env_value

             # expand "~" from the values
             return os.path.expanduser(value)
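For reference, the new `:`/`+` operators can be exercised standalone; a minimal sketch using the same regex (error handling trimmed, with `KeyError` standing in for `EnvVarError`):

```python
import os
import re

PATTERN = r"\${env\.([A-Z0-9_]+)(?:([:\+])([^}]*))?}"


def substitute(text: str) -> str:
    def get_env_var(match: re.Match) -> str:
        env_var, operator, value_expr = match.groups()
        env_value = os.environ.get(env_var)
        if operator == ":":  # ${env.FOO:default} -> default when FOO is unset
            if env_value:
                return env_value
            if value_expr is None:
                raise KeyError(env_var)
            return value_expr
        if operator == "+":  # ${env.FOO+value} -> value only when FOO is set
            return value_expr if env_value else ""
        if not env_value:  # bare ${env.FOO} is required
            raise KeyError(env_var)
        return env_value

    return re.sub(PATTERN, get_env_var, text)


os.environ["LLAMA_STACK_PORT"] = "8321"
print(substitute("port: ${env.LLAMA_STACK_PORT:5000}"))  # port: 8321
print(substitute("tls_arg: ${env.TLS_CERT+--ssl}"))      # tls_arg:  (TLS_CERT unset)
```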

View file

@@ -98,15 +98,20 @@ case "$env_type" in
   *)
 esac

-set -x
 if [[ "$env_type" == "venv" || "$env_type" == "conda" ]]; then
+  set -x
   $PYTHON_BINARY -m llama_stack.distribution.server.server \
     --yaml-config "$yaml_config" \
     --port "$port" \
     $env_vars \
     $other_args
 elif [[ "$env_type" == "container" ]]; then
+  # Check if container command is available
+  if ! is_command_available $CONTAINER_BINARY; then
+    printf "${RED}Error: ${CONTAINER_BINARY} command not found. Is ${CONTAINER_BINARY} installed and in your PATH?${NC}" >&2
+    exit 1
+  fi
+
-  if is_command_available selinuxenabled &> /dev/null && selinuxenabled; then
+  if is_command_available selinuxenabled && selinuxenabled; then
     # Disable SELinux labels
     CONTAINER_OPTS="$CONTAINER_OPTS --security-opt label=disable"
@@ -136,6 +141,8 @@ elif [[ "$env_type" == "container" ]]; then
     version_tag=$(curl -s $URL | jq -r '.info.version')
   fi

+  set -x
+
   $CONTAINER_BINARY run $CONTAINER_OPTS -it \
     -p $port:$port \
     $env_vars \

View file

@@ -1,72 +0,0 @@
#!/bin/bash
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
set -euo pipefail
RED='\033[0;31m'
NC='\033[0m' # No Color
error_handler() {
echo "Error occurred in script at line: ${1}" >&2
exit 1
}
trap 'error_handler ${LINENO}' ERR
if [ $# -lt 3 ]; then
echo "Usage: $0 <venv_path> <yaml_config> <port> <script_args...>"
exit 1
fi
venv_path="$1"
shift
yaml_config="$1"
shift
port="$1"
shift
# Initialize env_vars as an empty array
env_vars=""
other_args=""
# Process environment variables from --env arguments
while [[ $# -gt 0 ]]; do
case "$1" in
--env)
if [[ -n "$2" ]]; then
env_vars="$env_vars --env $2"
shift 2
else
echo -e "${RED}Error: --env requires a KEY=VALUE argument${NC}" >&2
exit 1
fi
;;
*)
other_args="$other_args $1"
shift
;;
esac
done
echo "Using virtual environment: $venv_path"
# Activate virtual environment
if [ ! -d "$venv_path" ]; then
echo -e "${RED}Error: Virtual environment not found at $venv_path${NC}" >&2
exit 1
fi
source "$venv_path/bin/activate"
set -x
python -m llama_stack.distribution.server.server \
--yaml-config "$yaml_config" \
--port "$port" \
$env_vars \
$other_args

View file

@@ -134,7 +134,7 @@ def rag_chat_page():
                 dict(
                     name="builtin::rag/knowledge_search",
                     args={
-                        "vector_db_ids": [vector_db_id for vector_db_id in selected_vector_dbs],
+                        "vector_db_ids": list(selected_vector_dbs),
                     },
                 )
             ],

View file

@@ -46,7 +46,7 @@ def formulate_run_args(image_type, image_name, config, template_name) -> list:
         conda_env_info = json.loads(subprocess.check_output(["conda", "info", "--envs", "--json"]).decode())
         envs = conda_env_info["envs"]
         for envpath in envs:
-            if envpath.endswith(env_name):
+            if os.path.basename(envpath) == env_name:
                 return envpath
         return None
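The `endswith` to `os.path.basename` change fixes a real matching bug: a suffix test can select the wrong conda environment when one env name is a suffix of another. A quick illustration (paths invented for the example):

```python
import os

envs = ["/opt/conda/envs/my-stack", "/opt/conda/envs/stack"]
env_name = "stack"

# Suffix matching accepts both paths, because "my-stack" also ends with "stack":
print([p for p in envs if p.endswith(env_name)])
# ['/opt/conda/envs/my-stack', '/opt/conda/envs/stack']

# Comparing the basename matches only the exact environment:
print([p for p in envs if os.path.basename(p) == env_name])
# ['/opt/conda/envs/stack']
```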

llama_stack/logcat.py (new file, 204 lines)
View file

@@ -0,0 +1,204 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

"""
Category-based logging utility for llama-stack.

This module provides a wrapper over the standard Python logging module that supports
categorized logging with environment variable control.

Usage:
    from llama_stack import logcat

    logcat.info("server", "Starting up...")
    logcat.debug("inference", "Processing request...")

Environment variable:
    LLAMA_STACK_LOGGING: Semicolon-separated list of category=level pairs
    Example: "server=debug;inference=warning"
"""

import datetime
import logging
import os
from typing import Dict

# ANSI color codes for terminal output
COLORS = {
    "RESET": "\033[0m",
    "DEBUG": "\033[36m",  # Cyan
    "INFO": "\033[32m",  # Green
    "WARNING": "\033[33m",  # Yellow
    "ERROR": "\033[31m",  # Red
    "CRITICAL": "\033[35m",  # Magenta
    "DIM": "\033[2m",  # Dimmed text
    "YELLOW_DIM": "\033[2;33m",  # Dimmed yellow
}

# Static list of valid categories representing various parts of the Llama Stack
# server codebase
CATEGORIES = [
    "core",
    "server",
    "router",
    "inference",
    "agents",
    "safety",
    "eval",
    "tools",
    "client",
]

_logger = logging.getLogger("llama_stack")
_logger.propagate = False

_default_level = logging.INFO

# Category-level mapping (can be modified by environment variables)
_category_levels: Dict[str, int] = {}


class TerminalStreamHandler(logging.StreamHandler):
    def __init__(self, stream=None):
        super().__init__(stream)
        self.is_tty = hasattr(self.stream, "isatty") and self.stream.isatty()

    def format(self, record):
        record.is_tty = self.is_tty
        return super().format(record)


class ColoredFormatter(logging.Formatter):
    """Custom formatter with colors and fixed-width level names"""

    def format(self, record):
        levelname = record.levelname
        # Use only time with milliseconds, not date
        timestamp = datetime.datetime.now().strftime("%H:%M:%S.%f")[:-3]  # HH:MM:SS.mmm format
        file_info = f"{record.filename}:{record.lineno}"

        # Get category from extra if available
        category = getattr(record, "category", None)
        msg = record.getMessage()

        if getattr(record, "is_tty", False):
            color = COLORS.get(levelname, COLORS["RESET"])
            if category:
                category_formatted = f"{COLORS['YELLOW_DIM']}{category}{COLORS['RESET']} "
                formatted_msg = (
                    f"{color}{levelname:<7}{COLORS['RESET']} {COLORS['DIM']}{timestamp}{COLORS['RESET']} "
                    f"{file_info:<20} {category_formatted}{msg}"
                )
            else:
                formatted_msg = (
                    f"{color}{levelname:<7}{COLORS['RESET']} {COLORS['DIM']}{timestamp}{COLORS['RESET']} "
                    f"{file_info:<20} {msg}"
                )
        else:
            if category:
                formatted_msg = f"{levelname:<7} {timestamp} {file_info:<20} [{category}] {msg}"
            else:
                formatted_msg = f"{levelname:<7} {timestamp} {file_info:<20} {msg}"

        return formatted_msg


def init(default_level: int = logging.INFO) -> None:
    global _default_level, _category_levels, _logger

    _default_level = default_level

    _logger.setLevel(logging.DEBUG)
    _logger.handlers = []  # Clear existing handlers

    # Add our custom handler with the colored formatter
    handler = TerminalStreamHandler()
    formatter = ColoredFormatter()
    handler.setFormatter(formatter)
    _logger.addHandler(handler)

    for category in CATEGORIES:
        _category_levels[category] = default_level

    env_config = os.environ.get("LLAMA_STACK_LOGGING", "")
    if env_config:
        for pair in env_config.split(";"):
            if not pair.strip():
                continue

            try:
                category, level = pair.split("=", 1)
                category = category.strip().lower()
                level = level.strip().lower()

                level_value = {
                    "debug": logging.DEBUG,
                    "info": logging.INFO,
                    "warning": logging.WARNING,
                    "warn": logging.WARNING,
                    "error": logging.ERROR,
                    "critical": logging.CRITICAL,
                }.get(level)

                if level_value is None:
                    _logger.warning(f"Unknown log level '{level}' for category '{category}'")
                    continue

                if category == "all":
                    for cat in CATEGORIES:
                        _category_levels[cat] = level_value
                else:
                    if category in CATEGORIES:
                        _category_levels[category] = level_value
                    else:
                        _logger.warning(f"Unknown logging category: {category}")
            except ValueError:
                _logger.warning(f"Invalid logging configuration: {pair}")


def _should_log(level: int, category: str) -> bool:
    category = category.lower()
    if category not in _category_levels:
        return False
    category_level = _category_levels[category]
    return level >= category_level


def _log(level: int, level_name: str, category: str, msg: str, *args, **kwargs) -> None:
    if _should_log(level, category):
        kwargs.setdefault("extra", {})["category"] = category.lower()
        getattr(_logger, level_name)(msg, *args, stacklevel=3, **kwargs)


def debug(category: str, msg: str, *args, **kwargs) -> None:
    _log(logging.DEBUG, "debug", category, msg, *args, **kwargs)


def info(category: str, msg: str, *args, **kwargs) -> None:
    _log(logging.INFO, "info", category, msg, *args, **kwargs)


def warning(category: str, msg: str, *args, **kwargs) -> None:
    _log(logging.WARNING, "warning", category, msg, *args, **kwargs)


def warn(category: str, msg: str, *args, **kwargs) -> None:
    warning(category, msg, *args, **kwargs)


def error(category: str, msg: str, *args, **kwargs) -> None:
    _log(logging.ERROR, "error", category, msg, *args, **kwargs)


def critical(category: str, msg: str, *args, **kwargs) -> None:
    _log(logging.CRITICAL, "critical", category, msg, *args, **kwargs)


def exception(category: str, msg: str, *args, **kwargs) -> None:
    if _should_log(logging.ERROR, category):
        kwargs.setdefault("extra", {})["category"] = category.lower()
        _logger.exception(msg, *args, stacklevel=2, **kwargs)
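Putting the new module together, usage looks roughly like this (the `LLAMA_STACK_LOGGING` value is only an example; category names must come from the static `CATEGORIES` list above):

```python
import os

# Raise "server" to debug and quiet "inference" down to warnings
os.environ["LLAMA_STACK_LOGGING"] = "server=debug;inference=warning"

from llama_stack import logcat

logcat.init()
logcat.debug("server", "now visible, since server=debug")
logcat.info("inference", "filtered out, below the warning threshold")
logcat.warning("inference", "still visible")
logcat.info("nonexistent", "dropped silently: unknown categories never log")
```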

View file

@@ -11,16 +11,128 @@
 # top-level folder for each specific model found within the models/ directory at
 # the top-level of this source tree.

+import base64
 from enum import Enum
-from typing import Any, Dict, Literal, Optional, Union
+from io import BytesIO
+from typing import Any, Dict, List, Literal, Optional, Union

-# import all for backwards compatibility
-from llama_models.datatypes import *  # noqa: F403
-from pydantic import BaseModel, ConfigDict, Field, field_validator
+from pydantic import BaseModel, ConfigDict, Field, field_serializer, field_validator
 from typing_extensions import Annotated

 from llama_stack.schema_utils import json_schema_type, register_schema

+# The goal is that these set of types are relevant for all Llama models.
+# That isn't the current state yet -- e.g., BuiltinTool is somewhat specific to
+# the llama3 series of models.
+
+
+class Role(Enum):
+    system = "system"
+    user = "user"
+    assistant = "assistant"
+    tool = "tool"
+
+
+class BuiltinTool(Enum):
+    brave_search = "brave_search"
+    wolfram_alpha = "wolfram_alpha"
+    photogen = "photogen"
+    code_interpreter = "code_interpreter"
+
+
+Primitive = Union[str, int, float, bool, None]
+RecursiveType = Union[Primitive, List[Primitive], Dict[str, Primitive]]
+
+
+class ToolCall(BaseModel):
+    call_id: str
+    tool_name: Union[BuiltinTool, str]
+    arguments: Dict[str, RecursiveType]
+
+    @field_validator("tool_name", mode="before")
+    @classmethod
+    def validate_field(cls, v):
+        if isinstance(v, str):
+            try:
+                return BuiltinTool(v)
+            except ValueError:
+                return v
+        return v
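Worth noting about the `mode="before"` validator: it upgrades known builtin tool names to `BuiltinTool` members while leaving custom tool names as plain strings, so both of these validate (assuming the `ToolCall` definition above is importable):

```python
tc = ToolCall(call_id="1", tool_name="brave_search", arguments={})
print(type(tc.tool_name))  # <enum 'BuiltinTool'> -- recognized builtin

tc = ToolCall(call_id="2", tool_name="lookup_weather", arguments={"city": "Paris"})
print(type(tc.tool_name))  # <class 'str'> -- custom tools pass through unchanged
```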
+
+class ToolPromptFormat(Enum):
+    """Prompt format for calling custom / zero shot tools.
+
+    :cvar json: JSON format for calling tools. It takes the form:
+        {
+            "type": "function",
+            "function" : {
+                "name": "function_name",
+                "description": "function_description",
+                "parameters": {...}
+            }
+        }
+    :cvar function_tag: Function tag format, pseudo-XML. This looks like:
+        <function=function_name>(parameters)</function>
+    :cvar python_list: Python list. The output is a valid Python expression that can be
+        evaluated to a list. Each element in the list is a function call. Example:
+        ["function_name(param1, param2)", "function_name(param1, param2)"]
+    """
+
+    json = "json"
+    function_tag = "function_tag"
+    python_list = "python_list"
+
+
+class StopReason(Enum):
+    end_of_turn = "end_of_turn"
+    end_of_message = "end_of_message"
+    out_of_tokens = "out_of_tokens"
+
+
+class RawMediaItem(BaseModel):
+    type: Literal["image"] = "image"
+    data: bytes | BytesIO
+
+    model_config = ConfigDict(arbitrary_types_allowed=True)
+
+    @field_serializer("data")
+    def serialize_data(self, data: Optional[bytes], _info):
+        if data is None:
+            return None
+        return base64.b64encode(data).decode("utf-8")
+
+    @field_validator("data", mode="before")
+    @classmethod
+    def validate_data(cls, v):
+        if isinstance(v, str):
+            return base64.b64decode(v)
+        return v
class RawTextItem(BaseModel):
type: Literal["text"] = "text"
text: str
RawContentItem = Annotated[Union[RawTextItem, RawMediaItem], Field(discriminator="type")]
RawContent = str | RawContentItem | List[RawContentItem]
class RawMessage(BaseModel):
role: Literal["user"] | Literal["system"] | Literal["tool"] | Literal["assistant"]
content: RawContent
# This is for RAG but likely should be absorbed into content
context: Optional[RawContent] = None
# These are for the output message coming from the assistant
stop_reason: Optional[StopReason] = None
tool_calls: List[ToolCall] = Field(default_factory=list)
register_schema(ToolCall)
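A small sketch of how the types above behave (values illustrative): the `tool_name` validator coerces known builtin names to `BuiltinTool` while leaving unknown names as strings, and `RawMediaItem` round-trips raw bytes through base64 when serialized:

```python
from llama_stack.models.llama.datatypes import BuiltinTool, RawMediaItem, ToolCall

call = ToolCall(call_id="1", tool_name="brave_search", arguments={"query": "llama stack"})
assert call.tool_name is BuiltinTool.brave_search   # coerced by the validator
custom = ToolCall(call_id="2", tool_name="my_tool", arguments={"x": 1})
assert custom.tool_name == "my_tool"                # unknown names stay strings

item = RawMediaItem(data=b"\x89PNG\r\n")
dumped = item.model_dump()                          # bytes serialized as base64 text
assert RawMediaItem(**dumped).data == item.data     # validator decodes back to bytes
```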


@@ -1,15 +1,5 @@
-#!/bin/bash
-
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # All rights reserved.
 #
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
-
-VERSION="$1"
-
-set -euo pipefail
-set -x
-
-uv pip install -U --extra-index-url https://test.pypi.org/simple \
-  llama-stack==$VERSION llama-models==$VERSION llama-stack-client==$VERSION


@@ -0,0 +1,82 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# top-level folder for each specific model found within the models/ directory at
# the top-level of this source tree.
from dataclasses import dataclass
from enum import Enum
from typing import Optional
class QuantizationScheme(Enum):
int4_weight_int8_dynamic_activation = "int4_weight_int8_dynamic_activation"
@dataclass
class QuantizationArgs:
scheme: Optional[QuantizationScheme] = None
group_size: Optional[int] = None
spinquant: bool = False
def __init__(self, **kwargs):
for k, v in kwargs.items():
if k == "scheme":
setattr(self, k, QuantizationScheme(v))
else:
if hasattr(self, k):
setattr(self, k, v)
@dataclass
class LoRAArgs:
rank: int
scale: float
@dataclass
class ModelArgs:
dim: int = 4096
n_layers: int = 32
n_heads: int = 32
n_kv_heads: Optional[int] = None
vocab_size: int = -1
multiple_of: int = 256 # make SwiGLU hidden layer size multiple of large power of 2
ffn_dim_multiplier: Optional[float] = None
norm_eps: float = 1e-5
rope_theta: float = 500000
use_scaled_rope: bool = False
max_batch_size: int = 32
max_seq_len: int = 2048
# vision model params
vision_chunk_size: int = -1 # image resolution for image models
vision_max_num_chunks: int = 4
vision_num_cross_attention_layers: int = -1
quantization_args: Optional[QuantizationArgs] = None
lora_args: Optional[LoRAArgs] = None
def __init__(self, **kwargs):
for k, v in kwargs.items():
if k == "lora_args":
setattr(self, k, LoRAArgs(**v))
elif k == "quantization_args":
setattr(self, k, QuantizationArgs(**v))
else:
if hasattr(self, k):
setattr(self, k, v)
if self.n_kv_heads is None:
self.n_kv_heads = self.n_heads
assert self.n_kv_heads <= self.n_heads
assert self.n_heads % self.n_kv_heads == 0
assert self.dim % self.n_heads == 0
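A sketch of loading these args from a checkpoint-style params dict (values illustrative), using `ModelArgs` and `QuantizationScheme` as defined above; nested `quantization_args` / `lora_args` dicts are promoted to their typed classes by the custom `__init__`:

```python
params = {
    "dim": 4096,
    "n_layers": 32,
    "n_heads": 32,
    "n_kv_heads": 8,
    "vocab_size": 128256,
    "use_scaled_rope": True,
    "quantization_args": {"scheme": "int4_weight_int8_dynamic_activation", "group_size": 32},
}

args = ModelArgs(**params)
# the scheme string was converted to the enum, and the GQA invariants hold
assert args.quantization_args.scheme is QuantizationScheme.int4_weight_int8_dynamic_activation
assert args.n_kv_heads == 8 and args.n_heads % args.n_kv_heads == 0
```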


@@ -0,0 +1,282 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# top-level folder for each specific model found within the models/ directory at
# the top-level of this source tree.
import io
import uuid
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple
from PIL import Image as PIL_Image
from llama_stack.models.llama.datatypes import (
BuiltinTool,
RawContent,
RawMediaItem,
RawMessage,
RawTextItem,
Role,
StopReason,
ToolCall,
ToolPromptFormat,
)
from .tokenizer import Tokenizer
from .tool_utils import ToolUtils
@dataclass
class VisionInput:
mask: List[List[int]]
images: List[PIL_Image.Image]
@dataclass
class LLMInput:
tokens: List[int]
vision: Optional[VisionInput] = None
def role_str(role: Role) -> str:
role_strs = {
Role.user: "user",
Role.system: "system",
Role.tool: "ipython", # special
Role.assistant: "assistant",
}
return role_strs[role]
class ChatFormat:
possible_headers: Dict[Role, str]
def __init__(self, tokenizer: Tokenizer):
self.tokenizer = tokenizer
self.possible_headers = {role: f"<|start_header_id|>{role_str(role)}<|end_header_id|>\n\n" for role in Role}
self.vision_token = self.tokenizer.special_tokens["<|image|>"]
def _encode_header(self, role: str) -> List[int]:
tokens = []
tokens.append(self.tokenizer.special_tokens["<|start_header_id|>"])
tokens.extend(self.tokenizer.encode("ipython" if role == "tool" else role, bos=False, eos=False))
tokens.append(self.tokenizer.special_tokens["<|end_header_id|>"])
tokens.extend(self.tokenizer.encode("\n\n", bos=False, eos=False))
return tokens
def encode_content(self, content: RawContent) -> LLMInput:
tokens, images = self._encode_content(content, bos=True)
return self._model_input_from_tokens_images(tokens, images)
def _encode_content(self, content: RawContent, bos: bool = False) -> Tuple[List[int], List[PIL_Image.Image]]:
tokens = []
images = []
added_bos = False
def _process(c):
nonlocal added_bos, bos
if isinstance(c, str) or isinstance(c, RawTextItem):
if isinstance(c, RawTextItem):
c = c.text
tokens.extend(self.tokenizer.encode(c, bos=False if added_bos else bos, eos=False))
added_bos = True
elif isinstance(c, RawMediaItem):
bos = False if added_bos else bos
if bos:
tokens.append(self.tokenizer.special_tokens["<|begin_of_text|>"])
added_bos = True
tokens.append(self.vision_token)
bytes_io = io.BytesIO(c.data) if isinstance(c.data, bytes) else c.data
image = PIL_Image.open(bytes_io)
image = image.convert("RGB")
images.append(image)
if isinstance(content, list):
for c in content:
_process(c)
else:
_process(content)
return tokens, images
def encode_message(
self, message: RawMessage, tool_prompt_format: ToolPromptFormat
) -> Tuple[List[int], List[PIL_Image.Image]]:
tokens = self._encode_header(message.role)
images = []
def _process_content(c):
toks, imgs = self._encode_content(c)
tokens.extend(toks)
images.extend(imgs)
if (
message.role == "assistant"
and len(message.tool_calls) > 0
and message.tool_calls[0].tool_name == BuiltinTool.code_interpreter
):
tokens.append(self.tokenizer.special_tokens["<|python_tag|>"])
_process_content(message.content)
if message.role == "user" and message.context is not None:
# This is RAG context; why is it here in the chat format? I don't think
# this is needed and can be moved upwards
_process_content("\n\n")
_process_content(message.context)
if message.role == "assistant":
for t in message.tool_calls:
content = ToolUtils.encode_tool_call(t, tool_prompt_format)
_process_content(content)
eom = False
if message.role == "assistant":
eom = message.stop_reason == StopReason.end_of_message
tokens.append(self.tokenizer.special_tokens["<|eom_id|>" if eom else "<|eot_id|>"])
return tokens, images
def encode_dialog_prompt(
self,
messages: List[RawMessage],
tool_prompt_format: Optional[ToolPromptFormat] = None,
) -> LLMInput:
tool_prompt_format = tool_prompt_format or ToolPromptFormat.json
tokens = []
images = []
tokens.append(self.tokenizer.special_tokens["<|begin_of_text|>"])
for message in messages:
toks, imgs = self.encode_message(message, tool_prompt_format)
tokens.extend(toks)
images.extend(imgs)
# Add the start of an assistant message for the model to complete.
tokens.extend(self._encode_header("assistant"))
return self._model_input_from_tokens_images(tokens, images)
# TODO(this should be generic, not only for assistant messages)
def decode_assistant_message(self, tokens: List[int], stop_reason: StopReason) -> RawMessage:
content = self.tokenizer.decode(tokens)
return self.decode_assistant_message_from_content(content, stop_reason)
def decode_assistant_message_from_content(self, content: str, stop_reason: StopReason) -> RawMessage:
content = content.strip(" ")
header_str = self.possible_headers[Role.assistant]
if content.startswith(header_str):
content = content[len(header_str) :]
ipython = content.startswith("<|python_tag|>")
if ipython:
content = content[len("<|python_tag|>") :]
if content.endswith("<|eot_id|>"):
content = content[: -len("<|eot_id|>")]
stop_reason = StopReason.end_of_turn
elif content.endswith("<|eom_id|>"):
content = content[: -len("<|eom_id|>")]
stop_reason = StopReason.end_of_message
tool_name = None
tool_arguments = {}
custom_tool_info = ToolUtils.maybe_extract_custom_tool_call(content)
if custom_tool_info is not None:
tool_name, tool_arguments = custom_tool_info
            # Sometimes when the agent has custom tools alongside builtin tools,
            # it responds to builtin tool calls in the format of the custom tools.
            # This code tries to handle that case.
if tool_name in BuiltinTool.__members__:
tool_name = BuiltinTool[tool_name]
tool_arguments = {
"query": list(tool_arguments.values())[0],
}
else:
builtin_tool_info = ToolUtils.maybe_extract_builtin_tool_call(content)
if builtin_tool_info is not None:
tool_name, query = builtin_tool_info
tool_arguments = {
"query": query,
}
if tool_name in BuiltinTool.__members__:
tool_name = BuiltinTool[tool_name]
elif ipython:
tool_name = BuiltinTool.code_interpreter
tool_arguments = {
"code": content,
}
tool_calls = []
if tool_name is not None and tool_arguments is not None:
call_id = str(uuid.uuid4())
tool_calls.append(
ToolCall(
call_id=call_id,
tool_name=tool_name,
arguments=tool_arguments,
)
)
content = ""
return RawMessage(
role="assistant",
content=content,
stop_reason=stop_reason,
tool_calls=tool_calls,
)
def _model_input_from_tokens_images(self, tokens: List[int], images: List[PIL_Image.Image]) -> LLMInput:
vision_input = None
if len(images) > 0:
vision_input = VisionInput(
mask=create_vision_mask(tokens, self.vision_token),
images=images,
)
return LLMInput(
tokens=[128256 if token == self.vision_token else token for token in tokens],
vision=vision_input,
)
def create_vision_mask(
tokens: List[int],
vision_token: int,
) -> List[List[int]]:
vision_token_locations = [i for i, token in enumerate(tokens) if token == vision_token]
if len(vision_token_locations) == 0:
return []
if len(vision_token_locations) == 1:
# only one image present, unmask until end of sequence
return [[vision_token_locations[0], -1]]
vision_masks = [
[loc1, loc2] for loc1, loc2 in zip(vision_token_locations[:-1], vision_token_locations[1:], strict=False)
]
# last image will attend to all subsequent text
vision_masks.append([vision_token_locations[-1], len(tokens)])
# if there are two or more consecutive vision tokens,
# they should all attend to all subsequent
# text present
last_mask_end = vision_masks[-1][1]
for vision_mask in vision_masks[::-1]:
if vision_mask[0] == vision_mask[1] - 1:
vision_mask[1] = last_mask_end
last_mask_end = vision_mask[1]
return vision_masks
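A toy illustration of `create_vision_mask` as defined above (`9` stands in for the `<|image|>` token id; the token lists are made up):

```python
V = 9
# two images: each attends up to the next image, the last one to all remaining text
assert create_vision_mask([V, 5, 6, V, 7, 8], V) == [[0, 3], [3, 6]]
# a single image is unmasked until the end of the sequence (-1 sentinel)
assert create_vision_mask([5, V, 6], V) == [[1, -1]]
# consecutive image tokens all attend to all subsequent text
assert create_vision_mask([V, V, 7, 8], V) == [[0, 4], [1, 4]]
```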


@@ -14,20 +14,19 @@
 from pathlib import Path
 from typing import List, Optional

-from llama_models.datatypes import (
+from termcolor import colored
+
+from llama_stack.models.llama.datatypes import (
     BuiltinTool,
     RawMessage,
     StopReason,
     ToolCall,
+    ToolDefinition,
     ToolPromptFormat,
 )
-from llama_models.llama3.api.chat_format import ChatFormat
-from llama_models.llama3.api.tokenizer import Tokenizer
-from termcolor import colored
-
-from llama_stack.models.llama.datatypes import ToolDefinition

 from . import template_data
+from .chat_format import ChatFormat
 from .prompt_templates import (
     BuiltinToolGenerator,
     FunctionTagCustomToolGenerator,
@@ -35,6 +34,7 @@ from .prompt_templates import (
     SystemDefaultGenerator,
     ToolResponseGenerator,
 )
+from .tokenizer import Tokenizer

 THIS_DIR = Path(__file__).parent


@@ -0,0 +1,311 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# top-level folder for each specific model found within the models/ directory at
# the top-level of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
import math
from typing import Optional, Tuple
import fairscale.nn.model_parallel.initialize as fs_init
import torch
import torch.nn.functional as F
from fairscale.nn.model_parallel.layers import (
ColumnParallelLinear,
RowParallelLinear,
VocabParallelEmbedding,
)
from torch import nn
from .args import ModelArgs
class RMSNorm(torch.nn.Module):
def __init__(self, dim: int, eps: float = 1e-6):
super().__init__()
self.eps = eps
self.weight = nn.Parameter(torch.ones(dim))
def _norm(self, x):
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
def forward(self, x):
output = self._norm(x.float()).type_as(x)
return output * self.weight
def apply_scaling(freqs: torch.Tensor) -> torch.Tensor:
# Values obtained from grid search
scale_factor = 8
low_freq_factor = 1
high_freq_factor = 4
old_context_len = 8192 # original llama3 length
low_freq_wavelen = old_context_len / low_freq_factor
high_freq_wavelen = old_context_len / high_freq_factor
wavelen = 2 * torch.pi / freqs
new_freqs = torch.where(wavelen > low_freq_wavelen, freqs / scale_factor, freqs)
smooth = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor)
return torch.where(
(wavelen >= high_freq_wavelen) & (wavelen <= low_freq_wavelen),
(1 - smooth) * new_freqs / scale_factor + smooth * new_freqs,
new_freqs,
)
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0, use_scaled: bool = False):
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
t = torch.arange(end, device=freqs.device, dtype=torch.float32)
if use_scaled:
freqs = apply_scaling(freqs)
freqs = torch.outer(t, freqs)
freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64
return freqs_cis
def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
ndim = x.ndim
assert 0 <= 1 < ndim
assert freqs_cis.shape == (x.shape[1], x.shape[-1])
shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
return freqs_cis.view(*shape)
def apply_rotary_emb(
xq: torch.Tensor,
xk: torch.Tensor,
freqs_cis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
return xq_out.type_as(xq), xk_out.type_as(xk)
def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
"""torch.repeat_interleave(x, dim=2, repeats=n_rep)"""
bs, slen, n_kv_heads, head_dim = x.shape
if n_rep == 1:
return x
return (
x[:, :, :, None, :]
.expand(bs, slen, n_kv_heads, n_rep, head_dim)
.reshape(bs, slen, n_kv_heads * n_rep, head_dim)
)
class Attention(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
model_parallel_size = fs_init.get_model_parallel_world_size()
self.n_local_heads = args.n_heads // model_parallel_size
self.n_local_kv_heads = self.n_kv_heads // model_parallel_size
self.n_rep = self.n_local_heads // self.n_local_kv_heads
self.head_dim = args.dim // args.n_heads
self.wq = ColumnParallelLinear(
args.dim,
args.n_heads * self.head_dim,
bias=False,
gather_output=False,
init_method=lambda x: x,
)
self.wk = ColumnParallelLinear(
args.dim,
self.n_kv_heads * self.head_dim,
bias=False,
gather_output=False,
init_method=lambda x: x,
)
self.wv = ColumnParallelLinear(
args.dim,
self.n_kv_heads * self.head_dim,
bias=False,
gather_output=False,
init_method=lambda x: x,
)
self.wo = RowParallelLinear(
args.n_heads * self.head_dim,
args.dim,
bias=False,
input_is_parallel=True,
init_method=lambda x: x,
)
self.cache_k = torch.zeros(
(
args.max_batch_size,
args.max_seq_len,
self.n_local_kv_heads,
self.head_dim,
)
)
self.cache_v = torch.zeros(
(
args.max_batch_size,
args.max_seq_len,
self.n_local_kv_heads,
self.head_dim,
)
)
def forward(
self,
x: torch.Tensor,
start_pos: int,
freqs_cis: torch.Tensor,
mask: Optional[torch.Tensor],
):
bsz, seqlen, _ = x.shape
xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)
self.cache_k = self.cache_k.to(xq)
self.cache_v = self.cache_v.to(xq)
self.cache_k[:bsz, start_pos : start_pos + seqlen] = xk
self.cache_v[:bsz, start_pos : start_pos + seqlen] = xv
keys = self.cache_k[:bsz, : start_pos + seqlen]
values = self.cache_v[:bsz, : start_pos + seqlen]
# repeat k/v heads if n_kv_heads < n_heads
keys = repeat_kv(keys, self.n_rep) # (bs, cache_len + seqlen, n_local_heads, head_dim)
values = repeat_kv(values, self.n_rep) # (bs, cache_len + seqlen, n_local_heads, head_dim)
xq = xq.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim)
keys = keys.transpose(1, 2) # (bs, n_local_heads, cache_len + seqlen, head_dim)
values = values.transpose(1, 2) # (bs, n_local_heads, cache_len + seqlen, head_dim)
scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim)
if mask is not None:
scores = scores + mask # (bs, n_local_heads, seqlen, cache_len + seqlen)
scores = F.softmax(scores.float(), dim=-1).type_as(xq)
output = torch.matmul(scores, values) # (bs, n_local_heads, seqlen, head_dim)
output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
return self.wo(output)
class FeedForward(nn.Module):
def __init__(
self,
dim: int,
hidden_dim: int,
multiple_of: int,
ffn_dim_multiplier: Optional[float],
):
super().__init__()
hidden_dim = int(2 * hidden_dim / 3)
# custom dim factor multiplier
if ffn_dim_multiplier is not None:
hidden_dim = int(ffn_dim_multiplier * hidden_dim)
hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
self.w1 = ColumnParallelLinear(dim, hidden_dim, bias=False, gather_output=False, init_method=lambda x: x)
self.w2 = RowParallelLinear(hidden_dim, dim, bias=False, input_is_parallel=True, init_method=lambda x: x)
self.w3 = ColumnParallelLinear(dim, hidden_dim, bias=False, gather_output=False, init_method=lambda x: x)
def forward(self, x):
return self.w2(F.silu(self.w1(x)) * self.w3(x))
class TransformerBlock(nn.Module):
def __init__(self, layer_id: int, args: ModelArgs):
super().__init__()
self.n_heads = args.n_heads
self.dim = args.dim
self.head_dim = args.dim // args.n_heads
self.attention = Attention(args)
self.feed_forward = FeedForward(
dim=args.dim,
hidden_dim=4 * args.dim,
multiple_of=args.multiple_of,
ffn_dim_multiplier=args.ffn_dim_multiplier,
)
self.layer_id = layer_id
self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)
def forward(
self,
x: torch.Tensor,
start_pos: int,
freqs_cis: torch.Tensor,
mask: Optional[torch.Tensor],
):
h = x + self.attention(self.attention_norm(x), start_pos, freqs_cis, mask)
out = h + self.feed_forward(self.ffn_norm(h))
return out
class Transformer(nn.Module):
def __init__(self, params: ModelArgs):
super().__init__()
self.params = params
self.vocab_size = params.vocab_size
self.n_layers = params.n_layers
self.tok_embeddings = VocabParallelEmbedding(params.vocab_size, params.dim, init_method=lambda x: x)
self.layers = torch.nn.ModuleList()
for layer_id in range(params.n_layers):
self.layers.append(TransformerBlock(layer_id, params))
self.norm = RMSNorm(params.dim, eps=params.norm_eps)
self.output = ColumnParallelLinear(params.dim, params.vocab_size, bias=False, init_method=lambda x: x)
self.freqs_cis = precompute_freqs_cis(
params.dim // params.n_heads,
params.max_seq_len * 2,
params.rope_theta,
params.use_scaled_rope,
)
@torch.inference_mode()
def forward(self, tokens: torch.Tensor, start_pos: int):
_bsz, seqlen = tokens.shape
h = self.tok_embeddings(tokens)
self.freqs_cis = self.freqs_cis.to(h.device)
freqs_cis = self.freqs_cis[start_pos : start_pos + seqlen]
mask = None
if seqlen > 1:
mask = torch.full((seqlen, seqlen), float("-inf"), device=tokens.device)
mask = torch.triu(mask, diagonal=1)
# https://github.com/pytorch/pytorch/issues/100005
# torch.triu is buggy when the device is mps: filled values are
# nan instead of 0.
if mask.device.type == torch.device("mps").type:
mask = torch.nan_to_num(mask, nan=0.0)
# When performing key-value caching, we compute the attention scores
# only for the new sequence. Thus, the matrix of scores is of size
# (seqlen, cache_len + seqlen), and the only masked entries are (i, j) for
# j > cache_len + i, since row i corresponds to token cache_len + i.
mask = torch.hstack([torch.zeros((seqlen, start_pos), device=tokens.device), mask]).type_as(h)
for layer in self.layers:
h = layer(h, start_pos, freqs_cis, mask)
h = self.norm(h)
output = self.output(h).float()
return output
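A quick shape sketch for the rotary-embedding helpers above (toy sizes on CPU, purely illustrative):

```python
import torch

bsz, seqlen, n_heads, head_dim = 2, 8, 4, 16
freqs_cis = precompute_freqs_cis(head_dim, seqlen)  # (seqlen, head_dim // 2), complex64
xq = torch.randn(bsz, seqlen, n_heads, head_dim)
xk = torch.randn(bsz, seqlen, n_heads, head_dim)
xq_rot, xk_rot = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)
# rotation is applied per (position, head) pair without changing shapes
assert xq_rot.shape == xq.shape and xk_rot.shape == xk.shape
```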


@@ -0,0 +1,12 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# top-level folder for each specific model found within the models/ directory at
# the top-level of this source tree.


@@ -0,0 +1,179 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# top-level folder for each specific model found within the models/ directory at
# the top-level of this source tree.
# Copyright (c) Meta Platforms, Inc. and its affiliates.
import math
from logging import getLogger
import torch
import torch.nn.functional as F
from .utils import get_negative_inf_value, to_2tuple
logger = getLogger()
def resize_local_position_embedding(orig_pos_embed, grid_size):
"""
Resize position embedding for vision encoder.
Original position embedding is [n_tiles * n_tiles + 1, dim]
New position embedding will be [grid_size[0] * grid_size[1] + 1, dim]
"""
new_grid_size = to_2tuple(grid_size)
orig_grid_size = to_2tuple(int(math.sqrt(len(orig_pos_embed) - 1)))
new_pos_emb_tok, new_pos_emb_img = (
orig_pos_embed[:1],
orig_pos_embed[1:],
)
logger.info(f"resizing position embedding grid-size from {orig_grid_size} to {new_grid_size}")
new_pos_emb_img = new_pos_emb_img.reshape(1, orig_grid_size[0], orig_grid_size[1], -1).permute(0, 3, 1, 2)
new_pos_emb_img = F.interpolate(
new_pos_emb_img,
size=new_grid_size,
mode="bilinear",
align_corners=True,
)
new_pos_emb_img = new_pos_emb_img.permute(0, 2, 3, 1).reshape(1, new_grid_size[0] * new_grid_size[1], -1)[0]
new_pos_embed = torch.cat([new_pos_emb_tok, new_pos_emb_img], dim=0)
return new_pos_embed
def initialize_global_position_embedding_from_local(pos_and_cls_embed, grid_size, x_scale, y_scale):
"""
Takes a local position embedding for vision encoder and uses it
to initialize the global position embedding.
Input: local position embedding of shape [grid_size[0] * grid_size[1] + 1, dim]
Returns: global position embedding of shape [x_scale, y_scale, grid_size[0] * grid_size[1] + 1, dim]
Here x_scale and y_scale are the number of tiles along x-axis and y-axis respectively.
"""
pos_embed = pos_and_cls_embed[1:]
cls_embed = pos_and_cls_embed[0].view(1, 1, 1, -1)
grid_size = to_2tuple(grid_size)
new_pos_emb_img = pos_embed.reshape(1, grid_size[0], grid_size[1], -1).permute(0, 3, 1, 2)
new_grid_size = (x_scale * grid_size[0], y_scale * grid_size[1])
new_pos_emb_img = F.interpolate(
new_pos_emb_img,
size=new_grid_size,
mode="bilinear",
align_corners=True,
)
new_pos_emb_img = new_pos_emb_img.permute(0, 2, 3, 1)
new_pos_emb_img = new_pos_emb_img.view(x_scale, grid_size[0], y_scale, grid_size[1], -1)
new_pos_emb_img = new_pos_emb_img.permute(0, 2, 1, 3, 4).contiguous()
new_pos_emb_img = new_pos_emb_img.reshape(x_scale, y_scale, grid_size[0] * grid_size[1], -1)
cls_embed = cls_embed.expand(x_scale, y_scale, -1, -1)
pos_and_cls_embed = torch.cat([cls_embed, new_pos_emb_img], dim=2)
return pos_and_cls_embed
def resize_global_position_embedding(pos_and_cls_embed, grid_size, x_scale, y_scale):
"""
Takes a global position embedding for vision encoder and resizes it to new size.
Input: global position embedding of shape [x_old, y_old, old_grid_size[0] * old_grid_size[1] + 1, dim]
Returns: global position embedding of shape [x_scale, y_scale, grid_size[0] * grid_size[1] + 1, dim]
Here x_scale and y_scale are the number of tiles along x-axis and y-axis respectively.
"""
# first remove cls token
pos_embed = pos_and_cls_embed[:, :, 1:]
cls_embed = pos_and_cls_embed[:, :, 0].unsqueeze(2)
xs_old, ys_old, ntok, dim = pos_embed.shape
old_grid_size = int(math.sqrt(ntok))
# move to correct form for interpolation
pos_embed = pos_embed.view(xs_old, ys_old, old_grid_size, old_grid_size, dim)
pos_embed = pos_embed.permute(0, 2, 1, 3, 4).contiguous()
pos_embed = pos_embed.view(xs_old * old_grid_size, ys_old * old_grid_size, dim)
pos_embed = pos_embed.unsqueeze(0)
# interpolate
new_size = (grid_size[0] * x_scale, grid_size[1] * y_scale)
pos_embed = pos_embed.permute(0, 3, 1, 2)
pos_embed_resized = F.interpolate(
pos_embed,
size=new_size,
mode="bilinear",
align_corners=True,
)
pos_embed = pos_embed_resized.permute(0, 2, 3, 1)[0]
# move it back in place
pos_embed = pos_embed.view(x_scale, grid_size[0], y_scale, grid_size[1], dim)
pos_embed = pos_embed.permute(0, 2, 1, 3, 4).contiguous()
pos_embed = pos_embed.view(x_scale, y_scale, grid_size[0] * grid_size[1], dim)
# interpolate cls token
cls_embed = cls_embed.permute(2, 3, 0, 1)
cls_embed_resized = F.interpolate(
cls_embed,
size=(x_scale, y_scale),
mode="bilinear",
align_corners=True,
)
cls_embed = cls_embed_resized.permute(2, 3, 0, 1)
# add cls token back in
pos_and_cls_embed = torch.cat([cls_embed, pos_embed], dim=2)
return pos_and_cls_embed
def build_encoder_attention_mask(
x: torch.Tensor,
ar: torch.Tensor,
ntok: int,
num_chunks: int,
n_heads: int,
):
"""
Build vision encoder attention mask that omits padding tokens.
"""
masks = []
for arx in ar:
mask_i = torch.ones((num_chunks, x.shape[2], 1), dtype=x.dtype)
mask_i[: arx[0] * arx[1], :ntok] = 0
mask_i = mask_i.view(num_chunks * x.shape[2], -1)
mask_i = mask_i @ mask_i.T * get_negative_inf_value(x.dtype)
mask_i = mask_i.unsqueeze(0)
masks.append(mask_i)
masks = torch.stack(masks).to(x.device).expand(-1, n_heads, -1, -1)
return masks
def expand_num_tokens_to_mult8(x):
    num_pad_tokens = -x.shape[-2] % 8  # 0 when ntok is already a multiple of 8
if num_pad_tokens == 0:
return x, 0
else:
return (
torch.cat(
[
x,
torch.zeros(
(x.shape[0], x.shape[1], num_pad_tokens, x.shape[-1]),
dtype=x.dtype,
device=x.device,
),
],
dim=-2,
),
num_pad_tokens,
)
def contract_num_tokens_from_mult8(x, num_pad_tokens):
if num_pad_tokens == 0:
return x
return x[:, :, :-num_pad_tokens]
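A round-trip sketch for the two padding helpers above (toy tensor: batch=1, chunks=2, ntok=5, dim=3):

```python
import torch

x = torch.randn(1, 2, 5, 3)
padded, npad = expand_num_tokens_to_mult8(x)
assert padded.shape[-2] == 8 and npad == 3        # 5 tokens padded up to 8
assert torch.equal(contract_num_tokens_from_mult8(padded, npad), x)
```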


@@ -0,0 +1,408 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# top-level folder for each specific model found within the models/ directory at
# the top-level of this source tree.
import math
from collections import defaultdict
from logging import getLogger
from typing import Any, Optional, Set, Tuple
import torch
import torchvision.transforms as tv
from PIL import Image
from torchvision.transforms import functional as F
IMAGE_RES = 224
logger = getLogger()
class VariableSizeImageTransform(object):
"""
    This class accepts images of any size and dynamically resizes, pads, and chunks them
    based on the image aspect ratio and the number of image chunks we allow.
    The algorithm will NOT distort the image to fit a certain aspect ratio, because
    that leads to a significant degradation in image quality.
It can be summarized in 6 steps:
1. Find all possible canvas combinations of max_num_chunks;
2. Find the best canvas to fit the image;
3. Resize without distortion
4. Pad
5. Normalize
6. Chunk
For example, if an input image is of size 300x800, patch_size of 224,
and max_num_chunks = 8, it will find the closest aspect ratio that
is allowed within 8 image chunks, with some restrictions.
In this case, 2:4 = 2 horizontal patches and 4 vertical patches,
giving a total of 8 chunks.
    If resize_to_max_canvas, the image will be resized (without distortion)
    to the largest possible resolution. In this case, 336:896, then padded to 448:896,
    where we maintain the original aspect ratio and pad the rest with zeros.
This approach minimizes the amount of padding required for any arbitrary resolution.
However, if limit_upscaling_to_patch_size is set to True,
the upscaling will be limited to the patch size. In the example above,
the image would remain 300x800 (no upscaling), and then padded to 448:896.
The final output will therefore be of shape (8, 3, 224, 224), where 2x4
patches are coming from the resizing and chunking.
"""
def __init__(self, size: int = IMAGE_RES) -> None:
self.size = size
logger.info(f"VariableSizeImageTransform size: {self.size}")
self.to_tensor = tv.ToTensor()
self._mean = (0.48145466, 0.4578275, 0.40821073)
self._std = (0.26862954, 0.26130258, 0.27577711)
self.normalize = tv.Normalize(
mean=self._mean,
std=self._std,
inplace=True,
)
self.resample = tv.InterpolationMode.BILINEAR
@staticmethod
def get_factors(n: int) -> Set[int]:
"""
        Calculate all factors of a given number, i.e. the divisors that leave
no remainder. For example, if n=12, it will return {1, 2, 3, 4, 6, 12}.
Args:
n (int): The number to find factors for.
Returns:
set: A set containing all factors of the number.
"""
factors_set = set()
for i in range(1, int(n**0.5) + 1):
if n % i == 0:
factors_set.add(i)
factors_set.add(n // i)
return factors_set
def find_supported_resolutions(self, max_num_chunks: int, patch_size: int) -> torch.Tensor:
"""
        Computes all of the allowed resolutions for a fixed number of chunks
and patch_size. Useful for when dividing an image into chunks.
Args:
max_num_chunks (int): Maximum number of chunks for processing.
patch_size (int): Size of the side of the patch.
Returns:
torch.Tensor: List of possible resolutions as tuples (height, width).
Example:
>>> max_num_chunks = 5
>>> patch_size = 224
>>> find_supported_resolutions(max_num_chunks, patch_size)
tensor([(224, 896), (448, 448), (224, 224), (896, 224), (224, 672),
(672, 224), (224, 448), (448, 224)])
Given max_num_chunks=4, patch_size=224, it will create a dictionary:
{
0.25: [(1, 4)],
1.0: [(2, 2), (1, 1)],
4.0: [(4, 1)],
0.33: [(1, 3)],
3.0: [(3, 1)],
0.5: [(1, 2)],
2.0: [(2, 1)]
}
and return the resolutions multiplied by the patch_size:
[(1*224, 4*224), (2*224, 2*224), ..., (2*224, 1*224)]
"""
asp_dict = defaultdict(list)
for chunk_size in range(max_num_chunks, 0, -1):
_factors = sorted(self.get_factors(chunk_size))
_asp_ratios = [(factor, chunk_size // factor) for factor in _factors]
for height, width in _asp_ratios:
ratio_float = height / width
asp_dict[ratio_float].append((height, width))
# get the resolutions multiplied by the patch_size
possible_resolutions = []
for value in asp_dict.values():
for height, depth in value:
possible_resolutions.append((height * patch_size, depth * patch_size))
return possible_resolutions
@staticmethod
def get_max_res_without_distortion(
image_size: Tuple[int, int],
target_size: Tuple[int, int],
) -> Tuple[int, int]:
"""
        Determines the maximum resolution to which an image can be resized without distorting its
aspect ratio, based on the target resolution.
Args:
image_size (Tuple[int, int]): The original resolution of the image (height, width).
            target_size (Tuple[int, int]): The desired resolution to fit the image into (height, width).
Returns:
Tuple[int, int]: The optimal dimensions (height, width) to which the image should be resized.
Example:
>>> _get_max_res_without_distortion([200, 300], target_size = [450, 200])
(134, 200)
>>> _get_max_res_without_distortion([800, 600], target_size = [450, 1300])
(450, 338)
"""
original_width, original_height = image_size
target_width, target_height = target_size
scale_w = target_width / original_width
scale_h = target_height / original_height
if scale_w < scale_h:
new_width = target_width
new_height = min(math.floor(original_height * scale_w), target_height)
else:
new_height = target_height
new_width = min(math.floor(original_width * scale_h), target_width)
return new_width, new_height
def _pad(self, image: Image.Image, target_size) -> Image.Image:
new_width, new_height = target_size
new_im = Image.new(mode="RGB", size=(new_width, new_height), color=(0, 0, 0)) # type: ignore
new_im.paste(image)
return new_im
def _split(self, image: torch.Tensor, ncw: int, nch: int) -> torch.Tensor:
# Split image into number of required tiles (width x height)
num_channels, height, width = image.size()
image = image.view(num_channels, nch, height // nch, ncw, width // ncw)
# Permute dimensions to reorder the axes
image = image.permute(1, 3, 0, 2, 4).contiguous()
# Reshape into the desired output shape (batch_size * 4, num_channels, width/2, height/2)
image = image.view(ncw * nch, num_channels, height // nch, width // ncw)
return image
def resize_without_distortion(
self,
image: torch.Tensor,
target_size: Tuple[int, int],
max_upscaling_size: Optional[int],
) -> torch.Tensor:
"""
Used to resize an image to target_resolution, without distortion.
If target_size requires upscaling the image, the user can set max_upscaling_size to
limit the upscaling to a maximum size. In this case, since we rescale without distortion,
modifying target_size works as a boundary for the image's largest side.
Args:
resample (str): Resampling method used when resizing images.
Supports "nearest", "nearest_exact", "bilinear", "bicubic".
max_upscaling_size (int): The maximum size to upscale the image to.
If None, there is no limit.
Examples:
>>> target_size = (1000, 1200)
>>> max_upscaling_size = 600
>>> image_size = (400, 200)
>>> resize_without_distortion(image_size, target_size, max_upscaling_size)
(600, 300) # new_size_without_distortion
>>> target_size = (1000, 1200)
>>> max_upscaling_size = 600
>>> image_size = (2000, 200)
>>> resize_without_distortion(image_size, target_size, max_upscaling_size)
(1000, 100) # new_size_without_distortion
>>> target_size = (1000, 1200)
>>> max_upscaling_size = 2000
>>> image_size = (400, 200)
>>> resize_without_distortion(image_size, target_size, max_upscaling_size)
(1000, 500) # new_size_without_distortion
>>> target_size = (1000, 1200)
>>> max_upscaling_size = None
>>> image_size = (400, 200)
>>> resize_without_distortion(image_size, target_size, max_upscaling_size)
(1000, 500) # new_size_without_distortion
"""
image_width, image_height = image.size
image_size = (image_width, image_height)
# If target_size requires upscaling, we might want to limit the upscaling to max_upscaling_size
if max_upscaling_size is not None:
new_target_width = min(max(image_width, max_upscaling_size), target_size[0])
new_target_height = min(max(image_height, max_upscaling_size), target_size[1])
target_size = (new_target_width, new_target_height)
# resize to target_size while preserving aspect ratio
new_size_without_distortion = self.get_max_res_without_distortion(image_size, target_size)
image = F.resize(
image,
(new_size_without_distortion[1], new_size_without_distortion[0]),
interpolation=self.resample,
)
return image
def get_best_fit(
self,
image_size: Tuple[int, int],
possible_resolutions: torch.Tensor,
resize_to_max_canvas: bool = False,
) -> Tuple[int, int]:
"""
        Determines the best canvas from a list of possible resolutions to which
        an image can be resized without distortion.
For each possible resolution, calculates the scaling factors for
width and height, and selects the smallest one, which is the limiting side.
E.g. to match the canvas you can upscale height by 2x, and width by 1.5x,
therefore, the maximum upscaling you can do is min(2, 1.5) = 1.5.
If upscaling is possible (any of the scaling factors is greater than 1),
then picks the smallest upscaling factor > 1, unless resize_to_max_canvas is True.
If upscaling is not possible, then picks the largest scaling factor <= 1, i.e.
reduce downscaling as much as possible.
If there are multiple resolutions with the same max scale, we pick the one with the lowest area,
to minimize padding. E.g., the same image can be upscaled to 224x224 and 224x448, but the latter
has more padding.
Args:
image_size (Tuple[int, int]): A tuple containing the height and width of the image.
possible_resolutions (torch.Tensor): A tensor of shape (N, 2) where each
row represents a possible resolution (height, width).
            resize_to_max_canvas (bool): If True, will return the largest upscaling resolution.
Returns:
List[int]: The best resolution [height, width] for the given image.
Example:
>>> image_size = (200, 300)
>>> possible_resolutions = torch.tensor([[224, 672],
... [672, 224],
... [224, 448],
... [448, 224],
... [224, 224]])
            >>> get_best_fit(image_size, possible_resolutions)
[224, 448]
We have:
scale_w = tensor([2.2400, 0.7467, 1.4933, 0.7467, 0.7467])
scale_h = tensor([1.1200, 3.3600, 1.1200, 2.2400, 1.1200])
scales = tensor([1.1200, 0.7467, 1.1200, 0.7467, 0.7467])
Only one of the scales > 1:
upscaling_possible = tensor([1.1200, 1.1200])
smallest_rescale = tensor(1.1200)
        So we pick the resolution with the smallest area:
areas = tensor([150528, 100352]) # [672, 224], [224, 448]
optimal_canvas = tensor([224, 448])
"""
original_width, original_height = image_size
# get all possible resolutions heights/widths
target_widths, target_heights = (
possible_resolutions[:, 0],
possible_resolutions[:, 1],
)
# get scaling factors to resize the image without distortion
scale_w = target_widths / original_width
scale_h = target_heights / original_height
# get the min scale between width and height (limiting side -> no distortion)
scales = torch.where(scale_w > scale_h, scale_h, scale_w)
# filter only scales that allow upscaling
upscaling_options = scales[scales >= 1]
if len(upscaling_options) > 0:
if resize_to_max_canvas:
selected_scale = torch.max(upscaling_options)
else:
selected_scale = torch.min(upscaling_options)
else:
# no upscaling possible,
# get the minimum downscaling (max scale for scales<1)
downscaling_options = scales[scales < 1]
selected_scale = torch.max(downscaling_options)
# get all resolutions that support this scaling factor,
# e.g. you can upscale to 224x224, 224x448, 224x672 without distortion
chosen_canvas = possible_resolutions[scales == selected_scale]
# if there are multiple resolutions,
# get the one with minimum area to reduce padding
if len(chosen_canvas) > 1:
areas = chosen_canvas[:, 0] * chosen_canvas[:, 1]
optimal_idx = torch.argmin(areas)
optimal_canvas = chosen_canvas[optimal_idx]
else:
optimal_canvas = chosen_canvas[0]
return tuple(optimal_canvas.tolist())
def __call__(
self,
image: Image.Image,
max_num_chunks: int,
normalize_img: bool = True,
resize_to_max_canvas: bool = False,
) -> Tuple[Any, Any]:
"""
Args:
image (PIL.Image): Image to be resized.
max_num_chunks (int): Maximum number of chunks to split the image into.
normalize_img (bool): Whether to normalize the image.
resize_to_max_canvas (bool): Whether to resize the image to the maximum canvas size.
                If True, picks the canvas that allows the largest resizing without distortion.
If False, downsample as little as possible, including no resizing at all,
but never upsample, unless the image is smaller than the patch size.
"""
assert max_num_chunks > 0
assert isinstance(image, Image.Image), type(image)
w, h = image.size
possible_resolutions = self.find_supported_resolutions(max_num_chunks=max_num_chunks, patch_size=self.size)
possible_resolutions = torch.tensor(possible_resolutions)
best_resolution = self.get_best_fit(
image_size=(w, h),
possible_resolutions=possible_resolutions,
resize_to_max_canvas=resize_to_max_canvas,
)
max_upscaling_size = None if resize_to_max_canvas else self.size
image = self.resize_without_distortion(image, best_resolution, max_upscaling_size)
image = self._pad(image, best_resolution)
image = self.to_tensor(image)
if normalize_img:
image = self.normalize(image)
ratio_w, ratio_h = (
best_resolution[0] // self.size,
best_resolution[1] // self.size,
)
image = self._split(image, ratio_w, ratio_h) # type: ignore
ar = (ratio_h, ratio_w)
return image, ar
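A usage sketch mirroring the class docstring's 300x800 example (the input image here is synthetic):

```python
from PIL import Image

transform = VariableSizeImageTransform(size=224)
img = Image.new("RGB", (300, 800))  # width 300, height 800
chunks, ar = transform(img, max_num_chunks=8, resize_to_max_canvas=True)
assert chunks.shape == (8, 3, 224, 224)  # 2 tiles wide x 4 tall
assert ar == (4, 2)                      # (ratio_h, ratio_w)
```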

File diff suppressed because it is too large


@ -0,0 +1,26 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# top-level folder for each specific model found within the models/ directory at
# the top-level of this source tree.
import collections
import torch
def get_negative_inf_value(dtype):
return torch.finfo(dtype).min
def to_2tuple(x):
if isinstance(x, collections.abc.Iterable):
return x
return (x, x)


@@ -15,11 +15,8 @@ import textwrap
 from datetime import datetime
 from typing import Any, List, Optional

-from llama_models.datatypes import (
-    BuiltinTool,
-)
 from llama_stack.models.llama.datatypes import (
+    BuiltinTool,
     ToolDefinition,
     ToolParamDefinition,
 )
@@ -226,10 +223,9 @@ class FunctionTagCustomToolGenerator(PromptTemplateGeneratorBase):
 class PythonListCustomToolGenerator(PromptTemplateGeneratorBase):  # noqa: N801
     DEFAULT_PROMPT = textwrap.dedent(
         """
+        You are a helpful assistant. You have access to functions, but you should only use them if they are required.
+
         You are an expert in composing functions. You are given a question and a set of possible functions.
-        Based on the question, you will need to make one or more function/tool calls to achieve the purpose.
-        If none of the function can be used, point it out. If the given question lacks the parameters required by the function,
-        also point it out. You should only return the function call in tools call sections.
+        Based on the question, you may or may not need to make one function/tool call to achieve the purpose.

         {{ function_description }}
         """.strip("\n")


@@ -11,7 +11,7 @@
 # top-level folder for each specific model found within the models/ directory at
 # the top-level of this source tree.

-from llama_models.datatypes import (
+from llama_stack.models.llama.datatypes import (
     BuiltinTool,
     StopReason,
     ToolCall,

File diff suppressed because it is too large


@@ -0,0 +1,214 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# top-level folder for each specific model found within the models/ directory at
# the top-level of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
import os
from logging import getLogger
from pathlib import Path
from typing import (
AbstractSet,
Collection,
Dict,
Iterator,
List,
Literal,
Optional,
Sequence,
Union,
cast,
)
import tiktoken
from tiktoken.load import load_tiktoken_bpe
logger = getLogger(__name__)
# The tiktoken tokenizer can handle <=400k chars without
# pyo3_runtime.PanicException.
TIKTOKEN_MAX_ENCODE_CHARS = 400_000
# https://github.com/openai/tiktoken/issues/195
# Here we iterate over subsequences and split if we exceed the limit
# of max consecutive non-whitespace or whitespace characters.
MAX_NO_WHITESPACES_CHARS = 25_000
_INSTANCE = None
class Tokenizer:
"""
Tokenizing and encoding/decoding text using the Tiktoken tokenizer.
"""
special_tokens: Dict[str, int]
num_reserved_special_tokens = 256
pat_str = r"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+" # noqa: E501
@classmethod
def get_instance(cls):
global _INSTANCE
if _INSTANCE is None:
_INSTANCE = Tokenizer(os.path.join(os.path.dirname(__file__), "tokenizer.model"))
return _INSTANCE
def __init__(self, model_path: str):
"""
Initializes the Tokenizer with a Tiktoken model.
Args:
model_path (str): The path to the Tiktoken model file.
"""
assert os.path.isfile(model_path), model_path
mergeable_ranks = load_tiktoken_bpe(model_path)
num_base_tokens = len(mergeable_ranks)
special_tokens = [
"<|begin_of_text|>",
"<|end_of_text|>",
"<|reserved_special_token_0|>",
"<|reserved_special_token_1|>",
"<|finetune_right_pad_id|>",
"<|step_id|>",
"<|start_header_id|>",
"<|end_header_id|>",
"<|eom_id|>", # end of message
"<|eot_id|>", # end of turn
"<|python_tag|>",
"<|image|>",
]
reserved_tokens = [
f"<|reserved_special_token_{2 + i}|>" for i in range(self.num_reserved_special_tokens - len(special_tokens))
]
special_tokens = special_tokens + reserved_tokens
self.special_tokens = {token: num_base_tokens + i for i, token in enumerate(special_tokens)}
self.model = tiktoken.Encoding(
name=Path(model_path).name,
pat_str=self.pat_str,
mergeable_ranks=mergeable_ranks,
special_tokens=self.special_tokens,
)
self.n_words: int = num_base_tokens + len(special_tokens)
# BOS / EOS token IDs
self.bos_id: int = self.special_tokens["<|begin_of_text|>"]
self.eos_id: int = self.special_tokens["<|end_of_text|>"]
self.eot_id: int = self.special_tokens["<|eot_id|>"]
self.eom_id: int = self.special_tokens["<|eom_id|>"]
self.python_tag_id = self.special_tokens["<|python_tag|>"]
self.pad_id: int = self.special_tokens["<|finetune_right_pad_id|>"]
self.stop_tokens = [
self.eos_id,
self.special_tokens["<|eom_id|>"],
self.special_tokens["<|eot_id|>"],
]
def encode(
self,
s: str,
*,
bos: bool,
eos: bool,
allowed_special: Optional[Union[Literal["all"], AbstractSet[str]]] = None,
disallowed_special: Union[Literal["all"], Collection[str]] = (),
) -> List[int]:
"""
Encodes a string into a list of token IDs.
Args:
s (str): The input string to be encoded.
bos (bool): Whether to prepend the beginning-of-sequence token.
eos (bool): Whether to append the end-of-sequence token.
allowed_special ("all"|set[str]): allowed special tokens in string
disallowed_special ("all"|set[str]): special tokens that raise an error when in string
Returns:
list[int]: A list of token IDs.
By default, setting disallowed_special=() encodes a string by ignoring
special tokens. Specifically:
        - Setting `disallowed_special` to () will cause all text corresponding
          to special tokens to be encoded as natural text (instead of raising
          an error).
        - Setting `allowed_special` to "all" will cause all text corresponding
          to special tokens to be encoded as special tokens.
"""
if allowed_special is None:
allowed_special = set()
assert type(s) is str
substrs = (
substr
for i in range(0, len(s), TIKTOKEN_MAX_ENCODE_CHARS)
for substr in self._split_whitespaces_or_nonwhitespaces(
s[i : i + TIKTOKEN_MAX_ENCODE_CHARS], MAX_NO_WHITESPACES_CHARS
)
)
t: List[int] = []
for substr in substrs:
t.extend(
self.model.encode(
substr,
allowed_special=allowed_special,
disallowed_special=disallowed_special,
)
)
if bos:
t.insert(0, self.bos_id)
if eos:
t.append(self.eos_id)
return t
def decode(self, t: Sequence[int]) -> str:
"""
Decodes a list of token IDs into a string.
Args:
t (List[int]): The list of token IDs to be decoded.
Returns:
str: The decoded string.
"""
# Typecast is safe here. Tiktoken doesn't do anything list-related with the sequence.
return self.model.decode(cast(List[int], t))
@staticmethod
def _split_whitespaces_or_nonwhitespaces(s: str, max_consecutive_slice_len: int) -> Iterator[str]:
"""
Splits the string `s` so that each substring contains no more than `max_consecutive_slice_len`
consecutive whitespaces or consecutive non-whitespaces.
"""
current_slice_len = 0
current_slice_is_space = s[0].isspace() if len(s) > 0 else False
slice_start = 0
for i in range(len(s)):
is_now_space = s[i].isspace()
if current_slice_is_space ^ is_now_space:
current_slice_len = 1
current_slice_is_space = is_now_space
else:
current_slice_len += 1
if current_slice_len > max_consecutive_slice_len:
yield s[slice_start:i]
slice_start = i
current_slice_len = 1
yield s[slice_start:]
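A usage sketch for the tokenizer above (assumes the `tokenizer.model` file ships next to the module, as `get_instance` expects):

```python
tok = Tokenizer.get_instance()

ids = tok.encode("Hello world", bos=True, eos=False)
assert ids[0] == tok.bos_id
assert tok.decode(ids[1:]) == "Hello world"

# special-token text is encoded as plain text unless explicitly allowed
plain = tok.encode("<|eot_id|>", bos=False, eos=False)
special = tok.encode("<|eot_id|>", bos=False, eos=False, allowed_special="all")
assert special == [tok.eot_id] and plain != special
```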


@@ -0,0 +1,199 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# top-level folder for each specific model found within the models/ directory at
# the top-level of this source tree.
import ast
import json
import re
from typing import Optional, Tuple
from llama_stack.models.llama.datatypes import BuiltinTool, RecursiveType, ToolCall, ToolPromptFormat
BUILTIN_TOOL_PATTERN = r'\b(?P<tool_name>\w+)\.call\(query="(?P<query>[^"]*)"\)'
CUSTOM_TOOL_CALL_PATTERN = re.compile(r"<function=(?P<function_name>[^}]+)>(?P<args>{.*?})")
def is_json(s):
try:
parsed = json.loads(s)
# Return True for valid objects and not for ints, strings, etc
return isinstance(parsed, dict)
    except json.JSONDecodeError:
        return False
def is_valid_python_list(input_string):
"""Check if the input string is a valid Python list of function calls"""
try:
# Try to parse the string
tree = ast.parse(input_string)
# Check if it's a single expression
if len(tree.body) != 1 or not isinstance(tree.body[0], ast.Expr):
return False
# Check if the expression is a list
expr = tree.body[0].value
if not isinstance(expr, ast.List):
return False
# Check if the list is empty
if len(expr.elts) == 0:
return False
# Check if all elements in the list are function calls
for element in expr.elts:
if not isinstance(element, ast.Call):
return False
# Check if the function call has a valid name
if not isinstance(element.func, ast.Name):
return False
# Check if all arguments are keyword arguments
if element.args or not all(isinstance(arg, ast.keyword) for arg in element.keywords):
return False
return True
except SyntaxError:
# If parsing fails, it's not a valid Python expression
return False
def parse_python_list_for_function_calls(input_string):
"""
Parse a Python list of function calls and
return a list of tuples containing the function name and arguments
"""
# Parse the string into an AST
tree = ast.parse(input_string)
# Ensure the input is a list
if not isinstance(tree.body[0], ast.Expr) or not isinstance(tree.body[0].value, ast.List):
raise ValueError("Input must be a list of function calls")
result = []
# Iterate through each function call in the list
for node in tree.body[0].value.elts:
if isinstance(node, ast.Call):
function_name = node.func.id
function_args = {}
# Extract keyword arguments
for keyword in node.keywords:
function_args[keyword.arg] = ast.literal_eval(keyword.value)
result.append((function_name, function_args))
return result
class ToolUtils:
@staticmethod
def is_builtin_tool_call(message_body: str) -> bool:
        match = re.search(BUILTIN_TOOL_PATTERN, message_body)
return match is not None
@staticmethod
def maybe_extract_builtin_tool_call(message_body: str) -> Optional[Tuple[str, str]]:
# Find the first match in the text
match = re.search(BUILTIN_TOOL_PATTERN, message_body)
# Check if a match is found and return it
if match:
tool_name = match.group("tool_name")
query = match.group("query")
return tool_name, query
else:
return None
@staticmethod
def maybe_extract_custom_tool_call(message_body: str) -> Optional[Tuple[str, str]]:
        # NOTE: Custom function tool calls are still experimental
        # Sometimes the response is of the form
        # {"type": "function", "name": "function_name", "parameters": {...}}
        # and sometimes
        # <function=function_name>(parameters)</function>
# Find the first match in the text
match = re.search(CUSTOM_TOOL_CALL_PATTERN, message_body)
if match:
tool_name = match.group("function_name")
query = match.group("args")
try:
return tool_name, json.loads(query.replace("'", '"'))
except Exception as e:
print("Exception while parsing json query for custom tool call", query, e)
return None
elif is_json(message_body):
response = json.loads(message_body)
if ("type" in response and response["type"] == "function") or ("name" in response):
function_name = response["name"]
args = response["parameters"]
return function_name, args
else:
return None
elif is_valid_python_list(message_body):
res = parse_python_list_for_function_calls(message_body)
# FIXME: Enable multiple tool calls
return res[0]
else:
return None
@staticmethod
def encode_tool_call(t: ToolCall, tool_prompt_format: ToolPromptFormat) -> str:
if t.tool_name == BuiltinTool.brave_search:
q = t.arguments["query"]
return f'brave_search.call(query="{q}")'
elif t.tool_name == BuiltinTool.wolfram_alpha:
q = t.arguments["query"]
return f'wolfram_alpha.call(query="{q}")'
elif t.tool_name == BuiltinTool.photogen:
q = t.arguments["query"]
return f'photogen.call(query="{q}")'
elif t.tool_name == BuiltinTool.code_interpreter:
return t.arguments["code"]
else:
fname = t.tool_name
if tool_prompt_format == ToolPromptFormat.json:
return json.dumps(
{
"type": "function",
"name": fname,
"parameters": t.arguments,
}
)
elif tool_prompt_format == ToolPromptFormat.function_tag:
args = json.dumps(t.arguments)
return f"<function={fname}>{args}</function>"
elif tool_prompt_format == ToolPromptFormat.python_list:
def format_value(value: RecursiveType) -> str:
if isinstance(value, str):
return f'"{value}"'
elif isinstance(value, (int, float, bool)) or value is None:
return str(value)
elif isinstance(value, list):
return f"[{', '.join(format_value(v) for v in value)}]"
elif isinstance(value, dict):
return f"{{{', '.join(f'{k}={format_value(v)}' for k, v in value.items())}}}"
else:
raise ValueError(f"Unsupported type: {type(value)}")
args_str = ", ".join(f"{k}={format_value(v)}" for k, v in t.arguments.items())
return f"[{fname}({args_str})]"
else:
raise ValueError(f"Unsupported tool prompt format: {tool_prompt_format}")

View file

@ -0,0 +1,358 @@
# Llama 3.1 - Prompt Formats
## Tokens
Here is a list of special tokens that are supported by Llama 3.1:
- `<|begin_of_text|>`: Specifies the start of the prompt
- `<|end_of_text|>`: Model will cease to generate more tokens. This token is generated only by the base models.
- `<|finetune_right_pad_id|>`: This token is used for padding text sequences to the same length in a batch.
- `<|start_header_id|>` and `<|end_header_id|>`: These tokens enclose the role for a particular message. The possible roles are: [system, user, assistant and ipython]
- `<|eom_id|>`: End of message. A message represents a possible stopping point for execution where the model can inform the executor that a tool call needs to be made. This is used for multi-step interactions between the model and any available tools. This token is emitted by the model when the Environment: ipython instruction is used in the system prompt, or if the model calls for a built-in tool.
- `<|eot_id|>`: End of turn. Represents when the model has determined that it has finished interacting with the user message that initiated its response. This is used in two scenarios:
- at the end of a direct interaction between the model and the user
- at the end of multiple interactions between the model and any available tools
This token signals to the executor that the model has finished generating a response.
- `<|python_tag|>`: Is a special tag used in the model's response to signify a tool call.
There are 4 different roles that are supported by Llama 3.1
- `system`: Sets the context in which to interact with the AI model. It typically includes rules, guidelines, or necessary information that helps the model respond effectively.
- `user`: Represents the human interacting with the model. It includes the inputs, commands, and questions to the model.
- `ipython`: A new role introduced in Llama 3.1. Semantically, this role means "tool". This role is used to mark messages with the output of a tool call when sent back to the model from the executor.
- `assistant`: Represents the response generated by the AI model based on the context provided in the `system`, `ipython` and `user` prompts.
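Putting the tokens and roles above together, here is a minimal sketch of how a chat prompt could be assembled from role/content pairs. The helper name `encode_dialog` is illustrative, not part of any official API, and the double newline after each header follows the spacing shown in the examples below.
```python
# Minimal sketch: assemble a Llama 3.1 chat prompt from role/content pairs.
def encode_dialog(messages: list[dict]) -> str:
    prompt = "<|begin_of_text|>"
    for m in messages:
        prompt += f"<|start_header_id|>{m['role']}<|end_header_id|>\n\n{m['content']}<|eot_id|>"
    # Leave the assistant header open so the model generates the next reply.
    prompt += "<|start_header_id|>assistant<|end_header_id|>\n\n"
    return prompt

print(encode_dialog([
    {"role": "system", "content": "You are a helpful assistant"},
    {"role": "user", "content": "Answer who are you in the form of jeopardy?"},
]))
```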
## Llama 3.1 Base Model
Text completion for Llama 3.1 base model uses this format.
##### Input Prompt Format
```
<|begin_of_text|>Color of sky is blue but sometimes can also be
```
##### Model Response Format
```
red, orange, yellow, green, purple, pink, brown, gray, black, white, and even rainbow colors. The color of the sky can change due to various reasons such as time of day, weather conditions, pollution, and atmospheric phenomena.
The color of the sky is primarily blue because of a phenomenon called
```
Note the special `<|begin_of_text|>` start tag at the beginning of the prompt.
## Llama 3.1 Instruct Model
## User and assistant conversation
Here is a regular multi-turn user assistant conversation and how it's formatted.
##### Input Prompt Format
```
<|begin_of_text|><|start_header_id|>system<|end_header_id|>
You are a helpful assistant<|eot_id|><|start_header_id|>user<|end_header_id|>
Answer who are you in the form of jeopardy?<|eot_id|><|start_header_id|>assistant<|end_header_id|>
```
##### Model Response Format
```
Here's my response
"What is a helpful assistant?"<|eot_id|>
```
## Tool Calling Formats
The three built-in tools (brave_search, wolfram_alpha, and code interpreter) can be turned on using the system prompt:
- Brave Search: Tool call to perform web searches.
- Wolfram Alpha: Tool call to perform complex mathematical calculations.
- Code Interpreter: Enables the model to output python code.
## Builtin Tool Calling
Here is an example of a conversation using brave search
##### Input Prompt Format
```
<|begin_of_text|><|start_header_id|>system<|end_header_id|>
Environment: ipython
Tools: brave_search, wolfram_alpha
Cutting Knowledge Date: December 2023
Today Date: 21 September 2024
You are a helpful assistant.
<|eot_id|><|start_header_id|>user<|end_header_id|>
Search the web for the latest price of 1oz gold?<|eot_id|><|start_header_id|>assistant<|end_header_id|>
```
##### Model Response Format
```
<|python_tag|>brave_search.call(query="latest price of 1oz gold")<|eom_id|>
```
- Just including Environment: ipython turns on code interpreter; therefore, you don't need to specify code interpretation on the Tools: line. The model can generate python code which is interpreted by the executor, with the result provided back to the model.
- The message body of the assistant response starts with a special tag <|python_tag|>
- As alluded to above, in such an environment, the model can generate <|eom_id|> instead of just the standard <|eot_id|> . The latter indicates the turn is finished, while the former indicates continued multi-step reasoning. That is, the model is expecting a continuation message with the output of the tool call.
- The model tool call response is of the form `tool.call(query="...")` where `tool` is `brave_search` or `wolfram_alpha`
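As a hedged illustration of how an executor might recognize such a built-in tool call, the sketch below uses a simple regular expression over the raw assistant text; the pattern and function name are assumptions made for this example, not the exact ones used inside Llama Stack.
```python
import re

# Illustrative pattern: matches e.g. brave_search.call(query="latest price of 1oz gold")
BUILTIN_CALL = re.compile(r'(?P<tool_name>\w+)\.call\(query="(?P<query>[^"]*)"\)')

def parse_builtin_call(text: str):
    # The model may wrap the call with <|python_tag|> and <|eom_id|>; strip them first.
    text = text.replace("<|python_tag|>", "").replace("<|eom_id|>", "").strip()
    match = BUILTIN_CALL.search(text)
    if match is None:
        return None
    return match.group("tool_name"), match.group("query")

print(parse_builtin_call('<|python_tag|>brave_search.call(query="latest price of 1oz gold")<|eom_id|>'))
# ('brave_search', 'latest price of 1oz gold')
```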
## Builtin Code Interpreter
Here is an actual example of the model responding with code
##### Input Prompt Format
```
<|begin_of_text|><|start_header_id|>system<|end_header_id|>
Environment: ipython<|eot_id|><|start_header_id|>user<|end_header_id|>
Write code to check if number is prime, use that to see if the number 7 is prime<|eot_id|><|start_header_id|>assistant<|end_header_id|>
```
##### Model Response Format
```
<|python_tag|>def is_prime(n):
if n <= 1
return False
for i in range(2, int(n**0.5) + 1):
if n % i == 0:
return False
return True
print(is_prime(7)) # Output: True<|eom_id|>
```
- The model starts with <|python_tag|> and continues writing python code that needs to be executed
- No explicit mention of code_interpreter in system prompt. `Environment: ipython` implicitly enables it.
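A minimal sketch of how an executor could handle such a response, assuming the generation is handed over as a raw string: strip the special tags, run the remaining python, and send the captured output back in an `ipython` message. `exec` here is purely illustrative; a real executor would sandbox the code.
```python
import contextlib
import io

def run_code_interpreter(assistant_text: str) -> str:
    # Strip the special tags surrounding the generated code.
    code = assistant_text.replace("<|python_tag|>", "").replace("<|eom_id|>", "")
    buffer = io.StringIO()
    # NOTE: exec without sandboxing is only acceptable in a toy example.
    with contextlib.redirect_stdout(buffer):
        exec(code, {})
    return buffer.getvalue()

output = run_code_interpreter('<|python_tag|>print(sum(range(10)))<|eom_id|>')
print(output)  # "45" -- this output would be sent back under the ipython role
```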
## Built-in tools full interaction
Here is a full interaction with the built-in tools including the tool response and the final assistant response.
##### Input Prompt Format
```
<|begin_of_text|><|start_header_id|>system<|end_header_id|>
Environment: ipython
Tools: brave_search, wolfram_alpha
<|eot_id|><|start_header_id|>user<|end_header_id|>
What is the 100th decimal of pi?<|eot_id|><|start_header_id|>assistant<|end_header_id|>
<|python_tag|>wolfram_alpha.call(query="100th decimal of pi")<|eom_id|><|start_header_id|>ipython<|end_header_id|>
{
"queryresult": {
"success": true,
"inputstring": "100th decimal of pi",
"pods": [
{
"title": "Input interpretation",
"subpods": [
{
"title": "",
"plaintext": "100th digit | π"
}
]
},
{
"title": "Nearby digits",
"subpods": [
{
"title": "",
"plaintext": "...86208998628034825342117067982148086513282306647093..."
}
]
},
{
"title": "Result",
"primary": true,
"subpods": [
{
"title": "",
"plaintext": "7"
}
]
}
]
}
}
<|eot_id|><|start_header_id|>assistant<|end_header_id|>
```
##### Model Response Format
```
The 100th decimal of pi is 7.<|eot_id|>
```
- Note the `<|python_tag|>` in the assistant response.
- Role is `ipython` for the wolfram alpha response that is passed back to the model.
- Final message from assistant has <|eot_id|> tag.
## Zero shot tool calling
## JSON based tool calling
Llama models can now output custom tool calls from a single message to allow easier tool calling.
The following prompts provide an example of how custom tools can be called from the output of the model.
It's important to note that the model itself does not execute the calls; it provides structured output to facilitate calling by an executor.
##### Input Prompt Format
```
<|begin_of_text|><|start_header_id|>system<|end_header_id|>
Environment: ipython
Cutting Knowledge Date: December 2023
Today Date: 21 September 2024
You are a helpful assistant.
<|eot_id|><|start_header_id|>user<|end_header_id|>
Answer the user's question by making use of the following functions if needed.
If none of the function can be used, please say so.
Here is a list of functions in JSON format:
{
"type": "function",
"function": {
"name": "trending_songs",
"description": "Returns the trending songs on a Music site",
"parameters": {
"type": "object",
"properties": [
{
"n": {
"type": "object",
"description": "The number of songs to return"
}
},
{
"genre": {
"type": "object",
"description": "The genre of the songs to return"
}
}
],
"required": ["n"]
}
}
}
Return function calls in JSON format.<|eot_id|><|start_header_id|>user<|end_header_id|>
Use tools to get latest trending songs<|eot_id|><|start_header_id|>assistant<|end_header_id|>
```
##### Model Response Format
```
<|python_tag|>{
"type": "function",
"name": "trending_songs",
"parameters": {
"n": "10",
"genre": "all"
}
}<|eom_id|>
```
- JSON format for providing tools needs name, description and parameters
- Model responds with `<|python_tag|>` and `<|eom_id|>` as `Environment: ipython` was in the system prompt
- Instructions for tools added as a user message
- Only single tool calls are supported as of now
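A small sketch of how the executor side might turn this response into a callable name and arguments, assuming the raw assistant text is available as a string; the tag stripping and the shape of the returned tuple are illustrative assumptions.
```python
import json

def parse_json_tool_call(text: str):
    payload = text.replace("<|python_tag|>", "").replace("<|eom_id|>", "").strip()
    call = json.loads(payload)
    if call.get("type") == "function" or "name" in call:
        return call["name"], call.get("parameters", {})
    return None

response = '<|python_tag|>{"type": "function", "name": "trending_songs", "parameters": {"n": "10", "genre": "all"}}<|eom_id|>'
print(parse_json_tool_call(response))
# ('trending_songs', {'n': '10', 'genre': 'all'})
```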
## Example of a user defined tool calling
## `<function>` based tool calling
Here is an example of how you could also write custom instructions for the model to do zero shot tool calling.
In this example, we define a custom tool calling format using the `<function>` tag.
##### Input Prompt Format
```
<|begin_of_text|><|start_header_id|>system<|end_header_id|>
Environment: ipython
Cutting Knowledge Date: December 2023
Today Date: 21 September 2024
You are a helpful assistant.
<|eot_id|><|start_header_id|>user<|end_header_id|>
You have access to the following functions:
Use the function 'trending_songs' to 'Returns the trending songs on a Music site':
{"name": "trending_songs", "description": "Returns the trending songs on a Music site", "parameters": {"genre": {"description": "The genre of the songs to return", "param_type": "str", "required": false}, "n": {"description": "The number of songs to return", "param_type": "int", "required": true}}}
Think very carefully before calling functions.
If you choose to call a function ONLY reply in the following format with no prefix or suffix:
<function=example_function_name>{"example_name": "example_value"}</function>
Reminder:
- If looking for real time information use relevant functions before falling back to brave_search
- Function calls MUST follow the specified format, start with <function= and end with </function>
- Required parameters MUST be specified
- Only call one function at a time
- Put the entire function call reply on one line<|eot_id|><|start_header_id|>user<|end_header_id|>
Use tools to get latest trending songs<|eot_id|><|start_header_id|>assistant<|end_header_id|>
```
##### Model Response Format
```
<function=trending_songs>{"n": 10}</function><|eot_id|>
```
- In this case, model does NOT respond with `<|python_tag|>` and ends with `<|eot_id|>`
- Instructions for tools added as a user message
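The sketch below shows one way an executor could extract the function name and JSON arguments from this `<function=...>` format; the regular expression is an assumption chosen to match the instructions given in the prompt above, not a published API.
```python
import json
import re

FUNCTION_TAG = re.compile(r"<function=(?P<name>[^>{]+)>(?P<args>.*?)</function>", re.DOTALL)

def parse_function_tag_call(text: str):
    match = FUNCTION_TAG.search(text)
    if match is None:
        return None
    return match.group("name"), json.loads(match.group("args"))

print(parse_function_tag_call('<function=trending_songs>{"n": 10}</function><|eot_id|>'))
# ('trending_songs', {'n': 10})
```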
Thank You!

View file

@ -14,7 +14,7 @@
import textwrap import textwrap
from typing import List from typing import List
from llama_models.datatypes import ( from llama_stack.models.llama.datatypes import (
BuiltinTool, BuiltinTool,
RawMessage, RawMessage,
StopReason, StopReason,

View file

@ -13,7 +13,7 @@
import json import json
import textwrap import textwrap
from llama_models.datatypes import ( from llama_stack.models.llama.datatypes import (
RawMessage, RawMessage,
StopReason, StopReason,
ToolCall, ToolCall,

View file

@ -14,7 +14,7 @@
import textwrap import textwrap
from pathlib import Path from pathlib import Path
from llama_models.datatypes import ( from llama_stack.models.llama.datatypes import (
RawMediaItem, RawMediaItem,
RawMessage, RawMessage,
RawTextItem, RawTextItem,

View file

@ -0,0 +1,286 @@
## User and assistant conversation
Here is a regular multi-turn user assistant conversation and how it's formatted.
##### Input Prompt Format
```
<|begin_of_text|><|start_header_id|>system<|end_header_id|>
You are a helpful assistant<|eot_id|><|start_header_id|>user<|end_header_id|>
Who are you?<|eot_id|><|start_header_id|>assistant<|end_header_id|>
```
##### Model Response Format
```
I'm an AI assistant, which means I'm a computer program designed to simulate conversation and answer questions to the best of my ability. I'm here to help you with any questions or tasks you may have, and I'll do my best to provide you with accurate and helpful information.
I don't have a personal name, but I'm often referred to as a "virtual assistant" or a "chatbot." I'm a machine learning model, which means I've been trained on a large dataset of text and can generate responses based on patterns and context.
I can help with a wide range of topics, from general knowledge and trivia to more specialized subjects like science, history, and technology. I can also assist with tasks like language translation, text summarization, and even generating creative content like stories or poetry.
So, what can I help you with today?<|eot_id|>
```
##### Notes
This format is unchanged from Llama3.1
## Zero shot function calling
For Llama3.2 1B and 3B instruct models, we are introducing a new format for zero shot function calling.
This new format is designed to be more flexible and powerful than the previous format.
All available functions can be provided in the system message. A key difference is in the format of how the assistant responds with function calls.
It is pythonic in the form of `[func1(params_name=params_value, params_name2=params_value2...), func2(params)]` instead of the `json` or `<function>` tag formats defined in Llama3.1.
Here is an example:
##### Input Prompt Format
```
<|begin_of_text|><|start_header_id|>system<|end_header_id|>
You are an expert in composing functions. You are given a question and a set of possible functions.
Based on the question, you will need to make one or more function/tool calls to achieve the purpose.
If none of the function can be used, point it out. If the given question lacks the parameters required by the function,
also point it out. You should only return the function call in tools call sections.
If you decide to invoke any of the function(s), you MUST put it in the format of [func_name1(params_name1=params_value1, params_name2=params_value2...), func_name2(params)]
You SHOULD NOT include any other text in the response.
Here is a list of functions in JSON format that you can invoke.
[
{
"name": "get_weather",
"description": "Get weather info for places",
"parameters": {
"type": "dict",
"required": [
"city"
],
"properties": {
"city": {
"type": "string",
"description": "The name of the city to get the weather for"
},
"metric": {
"type": "string",
"description": "The metric for weather. Options are: celsius, fahrenheit",
"default": "celsius"
}
}
}
}
]<|eot_id|><|start_header_id|>user<|end_header_id|>
What is the weather in SF and Seattle?<|eot_id|><|start_header_id|>assistant<|end_header_id|>
```
##### Model Response Format
```
[get_weather(city='San Francisco', metric='celsius'), get_weather(city='Seattle', metric='celsius')]<|eot_id|>
```
##### Notes
- The output supports multiple tool calls natively
- JSON format for defining the functions in the system prompt is similar to Llama3.1
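Since the response is a plain Python-style list of calls, it can be parsed with the standard `ast` module. The sketch below mirrors the approach of the `parse_python_list_for_function_calls` helper shown earlier in this diff, but it is a standalone, illustrative reimplementation rather than the library function itself.
```python
import ast

def parse_pythonic_tool_calls(text: str):
    text = text.replace("<|eot_id|>", "").strip()
    tree = ast.parse(text, mode="eval")
    calls = []
    for node in tree.body.elts:  # the response is a list literal of Call nodes
        name = node.func.id
        args = {kw.arg: ast.literal_eval(kw.value) for kw in node.keywords}
        calls.append((name, args))
    return calls

print(parse_pythonic_tool_calls(
    "[get_weather(city='San Francisco', metric='celsius'), get_weather(city='Seattle', metric='celsius')]<|eot_id|>"
))
# [('get_weather', {'city': 'San Francisco', 'metric': 'celsius'}),
#  ('get_weather', {'city': 'Seattle', 'metric': 'celsius'})]
```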
## Zero shot function calling with user message
While the default is to provide all function definitions in a system message, in Llama3.2 text models you can also provide information for all the available tools in a user message.
##### Input Prompt Format
```
<|begin_of_text|><|start_header_id|>user<|end_header_id|>
Questions: Can you retrieve the details for the user with the ID 7890, who has black as their special request?
Here is a list of functions in JSON format that you can invoke:
[
{
"name": "get_user_info",
"description": "Retrieve details for a specific user by their unique identifier. Note that the provided function is in Python 3 syntax.",
"parameters": {
"type": "dict",
"required": [
"user_id"
],
"properties": {
"user_id": {
"type": "integer",
"description": "The unique identifier of the user. It is used to fetch the specific user details from the database."
},
"special": {
"type": "string",
"description": "Any special information or parameters that need to be considered while fetching user details.",
"default": "none"
}
}
}
}
]
Should you decide to return the function call(s),Put it in the format of [func1(params_name=params_value, params_name2=params_value2...), func2(params)]
NO other text MUST be included.<|eot_id|><|start_header_id|>assistant<|end_header_id|>
```
##### Model Response Format
```
[get_user_info(user_id=7890, special='black')]<|eot_id|>
```
##### Notes
- The tool call format for the model is the same whether your function calls are provided in the system or user message.
- While builtin tool calls end with a <|eom_id|>, notice the <|eot_id|> for zero shot tool calls.
## Code Interpreter
Code Interpreter continues to work in 3.2 text models similar to Llama 3.1 model family.
Here is an example,
##### Input Prompt Format
```
<|begin_of_text|><|start_header_id|>system<|end_header_id|>
Environment: ipython
Cutting Knowledge Date: December 2023
Today Date: 24 September 2024
<|eot_id|><|start_header_id|>user<|end_header_id|>
Write code to check if number is prime. Use it to verify if number 7 is prime<|eot_id|><|start_header_id|>assistant<|end_header_id|>
```
##### Model Response Format
```
<|python_tag|>def is_prime(n):
if n <= 1:
return False
if n == 2:
return True
if n % 2 == 0:
return False
max_divisor = int(n**0.5) + 1
for d in range(3, max_divisor, 2):
if n % d == 0:
return False
return True
print(is_prime(7)) # Output: True<|eom_id|>
```
##### Notes
- Note `Environment: ipython` in the system prompt.
- Note that the response starts with `<|python_tag|>` and ends with `<|eom_id|>`
## Zero shot function calling E2E format
Here is an example of the e2e cycle of tool calls with the model in a multi-step way.
##### Input Prompt Format
```
<|begin_of_text|><|start_header_id|>system<|end_header_id|>
You are an expert in composing functions. You are given a question and a set of possible functions.
Based on the question, you will need to make one or more function/tool calls to achieve the purpose.
If none of the function can be used, point it out. If the given question lacks the parameters required by the function,
also point it out. You should only return the function call in tools call sections.
If you decide to invoke any of the function(s), you MUST put it in the format of [func_name1(params_name1=params_value1, params_name2=params_value2...), func_name2(params)]
You SHOULD NOT include any other text in the response.
Here is a list of functions in JSON format that you can invoke.
[
{
"name": "get_weather",
"description": "Get weather info for places",
"parameters": {
"type": "dict",
"required": [
"city"
],
"properties": {
"city": {
"type": "string",
"description": "The name of the city to get the weather for"
},
"metric": {
"type": "string",
"description": "The metric for weather. Options are: celsius, fahrenheit",
"default": "celsius"
}
}
}
}
]<|eot_id|><|start_header_id|>user<|end_header_id|>
What is the weather in SF?<|eot_id|><|start_header_id|>assistant<|end_header_id|>
<|python_tag|>[get_weather(city="San Francisco", metric="celsius")]<|eot_id|><|start_header_id|>ipython<|end_header_id|>
"25 C"<|eot_id|><|start_header_id|>assistant<|end_header_id|>
```
##### Model Response Format
```
The weather in San Francisco is 25 C.<|eot_id|>
```
##### Notes
- The output of the function call is provided back to the model as a tool response (in json format).
- Notice `<|start_header_id|>ipython<|end_header_id|>` as the header message preceding the tool response.
- The model finally summarizes the information from the tool response and returns the result to the user.
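Putting the pieces together, a hedged sketch of this multi-step loop might look like the following. The canned responses stand in for a real inference call (a real loop would send the growing prompt to an inference provider), and the weather lookup is a toy stand-in for an actual tool.
```python
import ast
import json

def get_weather(city: str, metric: str = "celsius") -> str:
    return "25 C"  # toy stand-in for a real weather lookup

TOOLS = {"get_weather": get_weather}

# Canned responses standing in for a real inference call.
CANNED_RESPONSES = iter([
    '[get_weather(city="San Francisco", metric="celsius")]',
    "The weather in San Francisco is 25 C.",
])

def generate(prompt: str) -> str:
    return next(CANNED_RESPONSES)

def run_turn(prompt: str) -> str:
    assistant_text = generate(prompt)
    if not assistant_text.lstrip().startswith("["):
        return assistant_text  # no tool call, the turn is already complete
    # Parse the pythonic tool-call list and execute the first call.
    call = ast.parse(assistant_text, mode="eval").body.elts[0]
    name = call.func.id
    args = {kw.arg: ast.literal_eval(kw.value) for kw in call.keywords}
    result = TOOLS[name](**args)
    # Append the assistant call and the tool output (ipython role),
    # then ask the model for its final answer.
    prompt += (
        f"<|python_tag|>{assistant_text}<|eot_id|>"
        f"<|start_header_id|>ipython<|end_header_id|>\n\n{json.dumps(result)}<|eot_id|>"
        f"<|start_header_id|>assistant<|end_header_id|>\n\n"
    )
    return generate(prompt)

print(run_turn("<system + user messages formatted as shown above>"))
# The weather in San Francisco is 25 C.
```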
## Prompt format for base models
For base models (Llama3.2-1B and Llama3.2-3B), the prompt format for a simple completion is as follows
##### Input Prompt Format
```
<|begin_of_text|>The color of the sky is blue but sometimes it can also be
```
##### Model Response Format
```
gray or even purple. The color of the sky can change depending on the time of day, the weather, and the amount of pollution in the air. The color of the sky can also be affected by the presence of dust, smoke, and other particles in the air.
## Step 1: Identify the factors that
```
##### Notes
Same as Llama3.1
Thank You!

View file

@ -0,0 +1,141 @@
## User and assistant conversation
Here is a regular multi-turn user assistant conversation and how it's formatted.
##### Input Prompt Format
```
<|begin_of_text|><|start_header_id|>system<|end_header_id|>
You are a helpful assistant<|eot_id|><|start_header_id|>user<|end_header_id|>
Who are you?<|eot_id|><|start_header_id|>assistant<|end_header_id|>
```
##### Model Response Format
```
I'm a helpful assistant, here to provide information, answer questions, and assist with tasks to the best of my abilities. I'm a large language model, which means I can understand and respond to natural language inputs, and I'm constantly learning and improving to provide more accurate and helpful responses.
I can help with a wide range of topics, from general knowledge and trivia to more specific areas like science, history, technology, and more. I can also assist with tasks like language translation, text summarization, and even generating creative content like stories or dialogues.
What can I help you with today?<|eot_id|>
```
##### Notes
This format is unchanged from Llama3.1
## User and assistant conversation with Images
This example shows how to pass an image to the model as part of the messages.
##### Input Prompt Format
```
<|begin_of_text|><|start_header_id|>user<|end_header_id|>
<|image|>Describe this image in two sentences<|eot_id|><|start_header_id|>assistant<|end_header_id|>
```
##### Model Response Format
```
The image depicts a small dog standing on a skateboard, with its front paws firmly planted on the board and its back paws slightly raised. The dog's fur is predominantly brown and white, with a distinctive black stripe running down its back, and it is wearing a black collar around its neck.<|eot_id|>
```
##### Notes
- The `<|image|>` tag is used to indicate presence of the image
- The model isn't an early fusion model, so it doesn't actually translate an image into several tokens. Instead the cross-attention layers take input "on the side" from a vision encoder
![Image](mm-model.png)
- It's important to position the <|image|> tag appropriately in the prompt. The image will only attend to the subsequent text tokens
- The <|image|> tag is part of the user message body, implying that it should only come after the header `<|start_header_id|>{role}<|end_header_id|>` in the message body
- We recommend using a single image in one prompt
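For illustration, constructing such a prompt in code is just string formatting with the <|image|> tag placed before the text it should condition; the raw image bytes themselves are passed to the model separately (through the vision encoder path), not embedded in the text. The helper name is illustrative.
```python
def build_vision_prompt(user_text: str) -> str:
    # The <|image|> tag sits inside the user message body, after the role header,
    # and before the text tokens that should attend to the image.
    return (
        "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n"
        f"<|image|>{user_text}<|eot_id|>"
        "<|start_header_id|>assistant<|end_header_id|>\n\n"
    )

print(build_vision_prompt("Describe this image in two sentences"))
```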
## Builtin and Zero Shot Tool Calling
Llama3.2 vision models follow the same tool calling format as Llama3.1 models when inputs are text only.
Use `Environment: ipython` to enable tools.
Add `Tools: {{tool_name1}},{{tool_name2}}` for each of the builtin tools.
The same builtin tools as Llama3.1 are available,
- code_interpreter (for executing python code)
- brave_search (to search the web)
- wolfram_alpha (for querying wolfram alpha for mathematical questions)
##### Input Prompt Format
```
<|begin_of_text|><|start_header_id|>system<|end_header_id|>
Environment: ipython
Tools: brave_search, wolfram_alpha
Cutting Knowledge Date: December 2023
Today Date: 23 September 2024
You are a helpful assistant.
<|eot_id|><|start_header_id|>user<|end_header_id|>
Search the web for the latest price of 1oz gold?<|eot_id|><|start_header_id|>assistant<|end_header_id|>
```
##### Model Response Format
```
<|python_tag|>brave_search.call(query="latest price of 1oz gold")<|eom_id|>
```
##### Notes
- Note the `<|python_tag|>` before `brave_search` function call.
- The `<|eom_id|>` tag is used to indicate the end of the message.
- Similar to Llama3.1, code_interpreter is not explicitly mentioned but is enabled via `Environment: ipython`.
- Tool Calling does NOT work with images in the prompt as of now.
## Prompt format for base models
For base models (Llama3.2-11B-Vision and Llama3.2-90B-Vision), the prompt format for a simple completion is as follows
##### Input Prompt Format
```
<|begin_of_text|>The color of the sky is blue but sometimes it can also be
```
##### Model Response Format
```
red, orange, pink, purple, and even black. The color of the sky is determined by the amount of sunlight that is scattered by the atmosphere and the amount of dust and water vapor present in the atmosphere. During sunrise and sunset, the sky can take on a range of colors due to the scattering of light by
```
##### Notes
- Same as Llama3.1
## Prompt format for base models with Image
For base models (Llama3.2-11B-Vision and Llama3.2-90B-Vision), here is an example of how the text completion format looks with an image,
##### Input Prompt Format
```
<|begin_of_text|><|image|>If I had to write a haiku for this one
```
##### Model Response Format
```
, it would be: A skateboarder's delight, a puppy on a board, a furry little thrill-seeker. This puppy is a true skateboarding enthusiast, always eager to hit the streets and show off his skills. He's a master of the board, gliding effortlessly across the pavement with grace and style.
```
##### Notes
- Note the placement of the special tags <|begin_of_text|> and <|image|>
Thank You!

View file

@ -14,7 +14,7 @@
import textwrap import textwrap
from typing import List from typing import List
from llama_models.datatypes import ( from llama_stack.models.llama.datatypes import (
BuiltinTool, BuiltinTool,
RawMessage, RawMessage,
StopReason, StopReason,

View file

@ -16,7 +16,9 @@ import textwrap
from pathlib import Path from pathlib import Path
from typing import List from typing import List
from llama_models.datatypes import ( from pydantic import BaseModel, Field
from llama_stack.models.llama.datatypes import (
RawContent, RawContent,
RawMediaItem, RawMediaItem,
RawMessage, RawMessage,
@ -25,7 +27,6 @@ from llama_models.datatypes import (
ToolCall, ToolCall,
ToolPromptFormat, ToolPromptFormat,
) )
from pydantic import BaseModel, Field
from .llama3.interface import LLama31Interface from .llama3.interface import LLama31Interface
from .llama3.template_data import ( from .llama3.template_data import (

View file

@ -61,7 +61,12 @@ from llama_stack.apis.inference import (
UserMessage, UserMessage,
) )
from llama_stack.apis.safety import Safety from llama_stack.apis.safety import Safety
from llama_stack.apis.tools import RAGDocument, ToolGroups, ToolInvocationResult, ToolRuntime from llama_stack.apis.tools import (
RAGDocument,
ToolGroups,
ToolInvocationResult,
ToolRuntime,
)
from llama_stack.apis.vector_io import VectorIO from llama_stack.apis.vector_io import VectorIO
from llama_stack.models.llama.datatypes import ( from llama_stack.models.llama.datatypes import (
BuiltinTool, BuiltinTool,
@ -120,13 +125,25 @@ class ChatAgent(ShieldRunnerMixin):
def turn_to_messages(self, turn: Turn) -> List[Message]: def turn_to_messages(self, turn: Turn) -> List[Message]:
messages = [] messages = []
# We do not want to keep adding RAG context to the input messages # NOTE: if a toolcall response is in a step, we do not add it when processing the input messages
# May be this should be a parameter of the agentic instance tool_call_ids = set()
# that can define its behavior in a custom way for step in turn.steps:
if step.step_type == StepType.tool_execution.value:
for response in step.tool_responses:
tool_call_ids.add(response.call_id)
for m in turn.input_messages: for m in turn.input_messages:
msg = m.model_copy() msg = m.model_copy()
# We do not want to keep adding RAG context to the input messages
# May be this should be a parameter of the agentic instance
# that can define its behavior in a custom way
if isinstance(msg, UserMessage): if isinstance(msg, UserMessage):
msg.context = None msg.context = None
if isinstance(msg, ToolResponseMessage):
if msg.call_id in tool_call_ids:
# NOTE: do not add ToolResponseMessage here, we'll add them in tool_execution steps
continue
messages.append(msg) messages.append(msg)
for step in turn.steps: for step in turn.steps:
@ -260,17 +277,24 @@ class ChatAgent(ShieldRunnerMixin):
raise ValueError(f"Session {request.session_id} not found") raise ValueError(f"Session {request.session_id} not found")
turns = await self.storage.get_session_turns(request.session_id) turns = await self.storage.get_session_turns(request.session_id)
if len(turns) == 0:
raise ValueError("No turns found for session")
messages = await self.get_messages_from_turns(turns) messages = await self.get_messages_from_turns(turns)
messages.extend(request.tool_responses) messages.extend(request.tool_responses)
last_turn = turns[-1]
last_turn_messages = self.turn_to_messages(last_turn)
last_turn_messages = [ last_turn_messages = [
x for x in messages if isinstance(x, UserMessage) or isinstance(x, ToolResponseMessage) x for x in last_turn_messages if isinstance(x, UserMessage) or isinstance(x, ToolResponseMessage)
] ]
# TODO: figure out whether we should add the tool responses to the last turn messages
last_turn_messages.extend(request.tool_responses)
# get the steps from the turn id # get the steps from the turn id
steps = [] steps = []
if len(turns) > 0: steps = turns[-1].steps
steps = turns[-1].steps
# mark tool execution step as complete # mark tool execution step as complete
# if there's no tool execution in progress step (due to storage, or tool call parsing on client), # if there's no tool execution in progress step (due to storage, or tool call parsing on client),
@ -509,9 +533,15 @@ class ChatAgent(ShieldRunnerMixin):
if documents: if documents:
await self.handle_documents(session_id, documents, input_messages, tool_defs) await self.handle_documents(session_id, documents, input_messages, tool_defs)
session_info = await self.storage.get_session_info(session_id)
# if the session has a memory bank id, let the memory tool use it
if session_info and session_info.vector_db_id:
toolgroup_args[RAG_TOOL_GROUP]["vector_db_ids"].append(session_info.vector_db_id)
output_attachments = [] output_attachments = []
n_iter = 0 n_iter = await self.storage.get_num_infer_iters_in_turn(session_id, turn_id) or 0
# Build a map of custom tools to their definitions for faster lookup # Build a map of custom tools to their definitions for faster lookup
client_tools = {} client_tools = {}
for tool in self.agent_config.client_tools: for tool in self.agent_config.client_tools:
@ -586,8 +616,20 @@ class ChatAgent(ShieldRunnerMixin):
if event.stop_reason is not None: if event.stop_reason is not None:
stop_reason = event.stop_reason stop_reason = event.stop_reason
span.set_attribute("stop_reason", stop_reason) span.set_attribute("stop_reason", stop_reason)
span.set_attribute("input", [m.model_dump_json() for m in input_messages]) span.set_attribute(
span.set_attribute("output", f"content: {content} tool_calls: {tool_calls}") "input",
json.dumps([json.loads(m.model_dump_json()) for m in input_messages]),
)
output_attr = json.dumps(
{
"content": content,
"tool_calls": [json.loads(t.model_dump_json()) for t in tool_calls],
}
)
span.set_attribute("output", output_attr)
n_iter += 1
await self.storage.set_num_infer_iters_in_turn(session_id, turn_id, n_iter)
stop_reason = stop_reason or StopReason.out_of_tokens stop_reason = stop_reason or StopReason.out_of_tokens
@ -624,6 +666,9 @@ class ChatAgent(ShieldRunnerMixin):
if n_iter >= self.agent_config.max_infer_iters: if n_iter >= self.agent_config.max_infer_iters:
log.info("Done with MAX iterations, exiting.") log.info("Done with MAX iterations, exiting.")
# NOTE: mark end_of_turn to indicate to client that we are done with the turn
# Do not continue the tool call loop after this point
message.stop_reason = StopReason.end_of_turn
yield message yield message
break break
@ -673,6 +718,9 @@ class ChatAgent(ShieldRunnerMixin):
# If tool is a client tool, yield CompletionMessage and return # If tool is a client tool, yield CompletionMessage and return
if tool_call.tool_name in client_tools: if tool_call.tool_name in client_tools:
# NOTE: mark end_of_message to indicate to client that it may
# call the tool and continue the conversation with the tool's response.
message.stop_reason = StopReason.end_of_message
await self.storage.set_in_progress_tool_call_step( await self.storage.set_in_progress_tool_call_step(
session_id, session_id,
turn_id, turn_id,
@ -758,16 +806,14 @@ class ChatAgent(ShieldRunnerMixin):
input_messages = input_messages + [message, result_message] input_messages = input_messages + [message, result_message]
n_iter += 1
async def _get_tool_defs( async def _get_tool_defs(
self, toolgroups_for_turn: Optional[List[AgentToolGroup]] = None self, toolgroups_for_turn: Optional[List[AgentToolGroup]] = None
) -> Tuple[List[ToolDefinition], Dict[str, str]]: ) -> Tuple[List[ToolDefinition], Dict[str, str]]:
# Determine which tools to include # Determine which tools to include
agent_config_toolgroups = set( agent_config_toolgroups = {
(toolgroup.name if isinstance(toolgroup, AgentToolGroupWithArgs) else toolgroup) toolgroup.name if isinstance(toolgroup, AgentToolGroupWithArgs) else toolgroup
for toolgroup in self.agent_config.toolgroups for toolgroup in self.agent_config.toolgroups
) }
toolgroups_for_turn_set = ( toolgroups_for_turn_set = (
agent_config_toolgroups agent_config_toolgroups
if toolgroups_for_turn is None if toolgroups_for_turn is None
@ -803,6 +849,11 @@ class ChatAgent(ShieldRunnerMixin):
toolgroup_name, tool_name = self._parse_toolgroup_name(toolgroup_name_with_maybe_tool_name) toolgroup_name, tool_name = self._parse_toolgroup_name(toolgroup_name_with_maybe_tool_name)
tools = await self.tool_groups_api.list_tools(toolgroup_id=toolgroup_name) tools = await self.tool_groups_api.list_tools(toolgroup_id=toolgroup_name)
if not tools.data:
available_tool_groups = ", ".join(
[t.identifier for t in (await self.tool_groups_api.list_tool_groups()).data]
)
raise ValueError(f"Toolgroup {toolgroup_name} not found, available toolgroups: {available_tool_groups}")
if tool_name is not None and not any(tool.identifier == tool_name for tool in tools.data): if tool_name is not None and not any(tool.identifier == tool_name for tool in tools.data):
raise ValueError( raise ValueError(
f"Tool {tool_name} not found in toolgroup {toolgroup_name}. Available tools: {', '.join([tool.identifier for tool in tools.data])}" f"Tool {tool_name} not found in toolgroup {toolgroup_name}. Available tools: {', '.join([tool.identifier for tool in tools.data])}"
@ -1025,9 +1076,6 @@ async def execute_tool_call_maybe(
group_name = tool_to_group.get(name, None) group_name = tool_to_group.get(name, None)
if group_name is None: if group_name is None:
raise ValueError(f"Tool {name} not found in any tool group") raise ValueError(f"Tool {name} not found in any tool group")
# get the arguments generated by the model and augment with toolgroup arg overrides for the agent
tool_call_args = tool_call.arguments
tool_call_args.update(toolgroup_args.get(group_name, {}))
if isinstance(name, BuiltinTool): if isinstance(name, BuiltinTool):
if name == BuiltinTool.brave_search: if name == BuiltinTool.brave_search:
name = WEB_SEARCH_TOOL name = WEB_SEARCH_TOOL
@ -1036,10 +1084,12 @@ async def execute_tool_call_maybe(
result = await tool_runtime_api.invoke_tool( result = await tool_runtime_api.invoke_tool(
tool_name=name, tool_name=name,
kwargs=dict( kwargs={
session_id=session_id, "session_id": session_id,
**tool_call_args, # get the arguments generated by the model and augment with toolgroup arg overrides for the agent
), **tool_call.arguments,
**toolgroup_args.get(group_name, {}),
},
) )
return result return result

View file

@ -194,17 +194,13 @@ class MetaReferenceAgentsImpl(Agents):
yield event yield event
async def get_agents_turn(self, agent_id: str, session_id: str, turn_id: str) -> Turn: async def get_agents_turn(self, agent_id: str, session_id: str, turn_id: str) -> Turn:
turn = await self.persistence_store.get(f"session:{agent_id}:{session_id}:{turn_id}") agent = await self.get_agent(agent_id)
turn = json.loads(turn) turn = await agent.storage.get_session_turn(session_id, turn_id)
turn = Turn(**turn)
return turn return turn
async def get_agents_step(self, agent_id: str, session_id: str, turn_id: str, step_id: str) -> AgentStepResponse: async def get_agents_step(self, agent_id: str, session_id: str, turn_id: str, step_id: str) -> AgentStepResponse:
turn = await self.persistence_store.get(f"session:{agent_id}:{session_id}:{turn_id}") turn = await self.get_agents_turn(agent_id, session_id, turn_id)
turn = json.loads(turn) for step in turn.steps:
turn = Turn(**turn)
steps = turn.steps
for step in steps:
if step.step_id == step_id: if step.step_id == step_id:
return AgentStepResponse(step=step) return AgentStepResponse(step=step)
raise ValueError(f"Provided step_id {step_id} could not be found") raise ValueError(f"Provided step_id {step_id} could not be found")
@ -215,20 +211,18 @@ class MetaReferenceAgentsImpl(Agents):
session_id: str, session_id: str,
turn_ids: Optional[List[str]] = None, turn_ids: Optional[List[str]] = None,
) -> Session: ) -> Session:
session = await self.persistence_store.get(f"session:{agent_id}:{session_id}") agent = await self.get_agent(agent_id)
session = Session(**json.loads(session), turns=[]) session_info = await agent.storage.get_session_info(session_id)
turns = [] if session_info is None:
raise ValueError(f"Session {session_id} not found")
turns = await agent.storage.get_session_turns(session_id)
if turn_ids: if turn_ids:
for turn_id in turn_ids: turns = [turn for turn in turns if turn.turn_id in turn_ids]
turn = await self.persistence_store.get(f"session:{agent_id}:{session_id}:{turn_id}")
turn = json.loads(turn)
turn = Turn(**turn)
turns.append(turn)
return Session( return Session(
session_name=session.session_name, session_name=session_info.session_name,
session_id=session_id, session_id=session_id,
turns=turns if turns else [], turns=turns,
started_at=session.started_at, started_at=session_info.started_at,
) )
async def delete_agents_session(self, agent_id: str, session_id: str) -> None: async def delete_agents_session(self, agent_id: str, session_id: str) -> None:

View file

@ -21,6 +21,7 @@ log = logging.getLogger(__name__)
class AgentSessionInfo(BaseModel): class AgentSessionInfo(BaseModel):
session_id: str session_id: str
session_name: str session_name: str
# TODO: is this used anywhere?
vector_db_id: Optional[str] = None vector_db_id: Optional[str] = None
started_at: datetime started_at: datetime
@ -85,6 +86,14 @@ class AgentPersistence:
turns.sort(key=lambda x: (x.completed_at or datetime.min)) turns.sort(key=lambda x: (x.completed_at or datetime.min))
return turns return turns
async def get_session_turn(self, session_id: str, turn_id: str) -> Optional[Turn]:
value = await self.kvstore.get(
key=f"session:{self.agent_id}:{session_id}:{turn_id}",
)
if not value:
return None
return Turn(**json.loads(value))
async def set_in_progress_tool_call_step(self, session_id: str, turn_id: str, step: ToolExecutionStep): async def set_in_progress_tool_call_step(self, session_id: str, turn_id: str, step: ToolExecutionStep):
await self.kvstore.set( await self.kvstore.set(
key=f"in_progress_tool_call_step:{self.agent_id}:{session_id}:{turn_id}", key=f"in_progress_tool_call_step:{self.agent_id}:{session_id}:{turn_id}",
@ -96,3 +105,15 @@ class AgentPersistence:
key=f"in_progress_tool_call_step:{self.agent_id}:{session_id}:{turn_id}", key=f"in_progress_tool_call_step:{self.agent_id}:{session_id}:{turn_id}",
) )
return ToolExecutionStep(**json.loads(value)) if value else None return ToolExecutionStep(**json.loads(value)) if value else None
async def set_num_infer_iters_in_turn(self, session_id: str, turn_id: str, num_infer_iters: int):
await self.kvstore.set(
key=f"num_infer_iters_in_turn:{self.agent_id}:{session_id}:{turn_id}",
value=str(num_infer_iters),
)
async def get_num_infer_iters_in_turn(self, session_id: str, turn_id: str) -> Optional[int]:
value = await self.kvstore.get(
key=f"num_infer_iters_in_turn:{self.agent_id}:{session_id}:{turn_id}",
)
return int(value) if value else None

View file

@ -3,6 +3,7 @@
# #
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
import json
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional
from tqdm import tqdm from tqdm import tqdm
@ -86,7 +87,6 @@ class MetaReferenceEvalImpl(
) -> Job: ) -> Job:
task_def = self.benchmarks[benchmark_id] task_def = self.benchmarks[benchmark_id]
dataset_id = task_def.dataset_id dataset_id = task_def.dataset_id
candidate = task_config.eval_candidate
scoring_functions = task_def.scoring_functions scoring_functions = task_def.scoring_functions
dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id) dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
validate_dataset_schema(dataset_def.dataset_schema, get_valid_schemas(Api.eval.value)) validate_dataset_schema(dataset_def.dataset_schema, get_valid_schemas(Api.eval.value))
@ -117,7 +117,7 @@ class MetaReferenceEvalImpl(
generations = [] generations = []
for i, x in tqdm(enumerate(input_rows)): for i, x in tqdm(enumerate(input_rows)):
assert ColumnName.chat_completion_input.value in x, "Invalid input row" assert ColumnName.chat_completion_input.value in x, "Invalid input row"
input_messages = eval(str(x[ColumnName.chat_completion_input.value])) input_messages = json.loads(x[ColumnName.chat_completion_input.value])
input_messages = [UserMessage(**x) for x in input_messages] input_messages = [UserMessage(**x) for x in input_messages]
# NOTE: only single-turn agent generation is supported. Create a new session for each input row # NOTE: only single-turn agent generation is supported. Create a new session for each input row
@ -159,7 +159,7 @@ class MetaReferenceEvalImpl(
generations = [] generations = []
for x in tqdm(input_rows): for x in tqdm(input_rows):
if ColumnName.completion_input.value in x: if ColumnName.completion_input.value in x:
input_content = eval(str(x[ColumnName.completion_input.value])) input_content = json.loads(x[ColumnName.completion_input.value])
response = await self.inference_api.completion( response = await self.inference_api.completion(
model=candidate.model, model=candidate.model,
content=input_content, content=input_content,
@ -167,9 +167,8 @@ class MetaReferenceEvalImpl(
) )
generations.append({ColumnName.generated_answer.value: response.completion_message.content}) generations.append({ColumnName.generated_answer.value: response.completion_message.content})
elif ColumnName.chat_completion_input.value in x: elif ColumnName.chat_completion_input.value in x:
chat_completion_input_str = str(x[ColumnName.chat_completion_input.value]) chat_completion_input_json = json.loads(x[ColumnName.chat_completion_input.value])
input_messages = eval(chat_completion_input_str) input_messages = [UserMessage(**x) for x in chat_completion_input_json]
input_messages = [UserMessage(**x) for x in input_messages]
messages = [] messages = []
if candidate.system_message: if candidate.system_message:
messages.append(candidate.system_message) messages.append(candidate.system_message)

View file

@ -23,13 +23,6 @@ from fairscale.nn.model_parallel.initialize import (
initialize_model_parallel, initialize_model_parallel,
model_parallel_is_initialized, model_parallel_is_initialized,
) )
from llama_models.llama3.api.args import ModelArgs
from llama_models.llama3.api.chat_format import ChatFormat, LLMInput
from llama_models.llama3.api.tokenizer import Tokenizer
from llama_models.llama3.reference_impl.model import Transformer
from llama_models.llama3.reference_impl.multimodal.model import (
CrossAttentionTransformer,
)
from lmformatenforcer import JsonSchemaParser, TokenEnforcer, TokenEnforcerTokenizerData from lmformatenforcer import JsonSchemaParser, TokenEnforcer, TokenEnforcerTokenizerData
from pydantic import BaseModel from pydantic import BaseModel
@ -46,6 +39,13 @@ from llama_stack.models.llama.datatypes import (
SamplingParams, SamplingParams,
TopPSamplingStrategy, TopPSamplingStrategy,
) )
from llama_stack.models.llama.llama3.args import ModelArgs
from llama_stack.models.llama.llama3.chat_format import ChatFormat, LLMInput
from llama_stack.models.llama.llama3.model import Transformer
from llama_stack.models.llama.llama3.multimodal.model import (
CrossAttentionTransformer,
)
from llama_stack.models.llama.llama3.tokenizer import Tokenizer
from llama_stack.models.llama.sku_list import resolve_model from llama_stack.models.llama.sku_list import resolve_model
from llama_stack.providers.utils.inference.prompt_adapter import ( from llama_stack.providers.utils.inference.prompt_adapter import (
ChatCompletionRequestWithRawContent, ChatCompletionRequestWithRawContent,

View file

@ -111,7 +111,7 @@ class MetaReferenceInferenceImpl(
) )
if llama_model is None: if llama_model is None:
raise ValueError( raise ValueError(
"Please make sure your llama_model in model metadata or model identifier is in llama-models SKU list" "Please make sure your llama_model in model metadata or model identifier is in Llama SKU list"
) )
self.model_registry_helper = ModelRegistryHelper( self.model_registry_helper = ModelRegistryHelper(
@ -208,7 +208,6 @@ class MetaReferenceInferenceImpl(
logprobs = [] logprobs = []
stop_reason = None stop_reason = None
tokenizer = self.generator.formatter.tokenizer
for token_result in self.generator.completion(request): for token_result in self.generator.completion(request):
tokens.append(token_result.token) tokens.append(token_result.token)
if token_result.text == "<|eot_id|>": if token_result.text == "<|eot_id|>":

View file

@ -9,10 +9,9 @@ from copy import deepcopy
from functools import partial from functools import partial
from typing import Any, Generator from typing import Any, Generator
from llama_models.llama3.api.chat_format import ChatFormat
from llama_models.llama3.api.tokenizer import Tokenizer
from llama_stack.models.llama.datatypes import Model from llama_stack.models.llama.datatypes import Model
from llama_stack.models.llama.llama3.chat_format import ChatFormat
from llama_stack.models.llama.llama3.tokenizer import Tokenizer
from llama_stack.models.llama.sku_list import resolve_model from llama_stack.models.llama.sku_list import resolve_model
from llama_stack.providers.utils.inference.prompt_adapter import ( from llama_stack.providers.utils.inference.prompt_adapter import (
ChatCompletionRequestWithRawContent, ChatCompletionRequestWithRawContent,

View file

@ -207,7 +207,7 @@ def maybe_parse_message(maybe_json: Optional[str]) -> Optional[ProcessingMessage
return parse_message(maybe_json) return parse_message(maybe_json)
except json.JSONDecodeError: except json.JSONDecodeError:
return None return None
except ValueError as e: except ValueError:
return None return None
@ -352,7 +352,7 @@ class ModelParallelProcessGroup:
if isinstance(obj, TaskResponse): if isinstance(obj, TaskResponse):
yield obj.result yield obj.result
except GeneratorExit as e: except GeneratorExit:
self.request_socket.send(encode_msg(CancelSentinel())) self.request_socket.send(encode_msg(CancelSentinel()))
while True: while True:
obj_json = self.request_socket.send() obj_json = self.request_socket.send()

Some files were not shown because too many files have changed in this diff