mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-12 04:50:39 +00:00
Merge branch 'meta-llama:main' into fix-ollama-rag
This commit is contained in:
commit
2868d8f793
215 changed files with 137658 additions and 2561 deletions
|
@ -8,6 +8,8 @@ repos:
|
||||||
rev: v5.0.0 # Latest stable version
|
rev: v5.0.0 # Latest stable version
|
||||||
hooks:
|
hooks:
|
||||||
- id: check-merge-conflict
|
- id: check-merge-conflict
|
||||||
|
- id: trailing-whitespace
|
||||||
|
exclude: '\.py$' # Exclude Python files as Ruff already handles them
|
||||||
- id: check-added-large-files
|
- id: check-added-large-files
|
||||||
args: ['--maxkb=1000']
|
args: ['--maxkb=1000']
|
||||||
- id: end-of-file-fixer
|
- id: end-of-file-fixer
|
||||||
|
@ -83,10 +85,8 @@ repos:
|
||||||
- id: distro-codegen
|
- id: distro-codegen
|
||||||
name: Distribution Template Codegen
|
name: Distribution Template Codegen
|
||||||
additional_dependencies:
|
additional_dependencies:
|
||||||
- rich
|
|
||||||
- pydantic
|
|
||||||
- uv==0.6.0
|
- uv==0.6.0
|
||||||
entry: uv run python -m llama_stack.scripts.distro_codegen
|
entry: uv run --extra codegen python -m llama_stack.scripts.distro_codegen
|
||||||
language: python
|
language: python
|
||||||
pass_filenames: false
|
pass_filenames: false
|
||||||
require_serial: true
|
require_serial: true
|
||||||
|
|
|
@ -70,6 +70,19 @@ $ uv pip install -e .
|
||||||
$ source .venv/bin/activate
|
$ source .venv/bin/activate
|
||||||
```
|
```
|
||||||
|
|
||||||
|
Note that you can create a dotenv file `.env` that includes necessary environment variables:
|
||||||
|
```
|
||||||
|
LLAMA_STACK_BASE_URL=http://localhost:8321
|
||||||
|
LLAMA_STACK_CLIENT_LOG=debug
|
||||||
|
LLAMA_STACK_PORT=8321
|
||||||
|
LLAMA_STACK_CONFIG=
|
||||||
|
```
|
||||||
|
|
||||||
|
And then use this dotenv file when running client SDK tests via the following:
|
||||||
|
```bash
|
||||||
|
$ uv run --env-file .env -- pytest -v tests/client-sdk/inference/test_text_inference.py
|
||||||
|
```
|
||||||
|
|
||||||
## Pre-commit Hooks
|
## Pre-commit Hooks
|
||||||
|
|
||||||
We use [pre-commit](https://pre-commit.com/) to run linting and formatting checks on your code. You can install the pre-commit hooks by running:
|
We use [pre-commit](https://pre-commit.com/) to run linting and formatting checks on your code. You can install the pre-commit hooks by running:
|
||||||
|
@ -110,15 +123,15 @@ Some tips about common tasks you work on while contributing to Llama Stack:
|
||||||
|
|
||||||
### Using `llama stack build`
|
### Using `llama stack build`
|
||||||
|
|
||||||
Building a stack image (conda / docker) will use the production version of the `llama-stack`, `llama-models` and `llama-stack-client` packages. If you are developing with a llama-stack repository checked out and need your code to be reflected in the stack image, set `LLAMA_STACK_DIR` and `LLAMA_MODELS_DIR` to the appropriate checked out directories when running any of the `llama` CLI commands.
|
Building a stack image (conda / docker) will use the production version of the `llama-stack` and `llama-stack-client` packages. If you are developing with a llama-stack repository checked out and need your code to be reflected in the stack image, set `LLAMA_STACK_DIR` and `LLAMA_STACK_CLIENT_DIR` to the appropriate checked out directories when running any of the `llama` CLI commands.
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
```bash
|
```bash
|
||||||
$ cd work/
|
$ cd work/
|
||||||
$ git clone https://github.com/meta-llama/llama-stack.git
|
$ git clone https://github.com/meta-llama/llama-stack.git
|
||||||
$ git clone https://github.com/meta-llama/llama-models.git
|
$ git clone https://github.com/meta-llama/llama-stack-client-python.git
|
||||||
$ cd llama-stack
|
$ cd llama-stack
|
||||||
$ LLAMA_STACK_DIR=$(pwd) LLAMA_MODELS_DIR=../llama-models llama stack build --template <...>
|
$ LLAMA_STACK_DIR=$(pwd) LLAMA_STACK_CLIENT_DIR=../llama-stack-client-python llama stack build --template <...>
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -216,8 +216,8 @@
|
||||||
"faiss-cpu",
|
"faiss-cpu",
|
||||||
"fastapi",
|
"fastapi",
|
||||||
"fire",
|
"fire",
|
||||||
"groq",
|
|
||||||
"httpx",
|
"httpx",
|
||||||
|
"litellm",
|
||||||
"matplotlib",
|
"matplotlib",
|
||||||
"nltk",
|
"nltk",
|
||||||
"numpy",
|
"numpy",
|
||||||
|
@ -431,6 +431,7 @@
|
||||||
"fire",
|
"fire",
|
||||||
"httpx",
|
"httpx",
|
||||||
"matplotlib",
|
"matplotlib",
|
||||||
|
"mcp",
|
||||||
"nltk",
|
"nltk",
|
||||||
"numpy",
|
"numpy",
|
||||||
"ollama",
|
"ollama",
|
||||||
|
|
1076
docs/_static/llama-stack-spec.html
vendored
1076
docs/_static/llama-stack-spec.html
vendored
File diff suppressed because it is too large
Load diff
894
docs/_static/llama-stack-spec.yaml
vendored
894
docs/_static/llama-stack-spec.yaml
vendored
File diff suppressed because it is too large
Load diff
File diff suppressed because one or more lines are too long
|
@ -45,65 +45,7 @@
|
||||||
"id": "O9pGVlPIjpix",
|
"id": "O9pGVlPIjpix",
|
||||||
"outputId": "e1fbe723-ae31-4630-eb80-4c4f6476d56f"
|
"outputId": "e1fbe723-ae31-4630-eb80-4c4f6476d56f"
|
||||||
},
|
},
|
||||||
"outputs": [
|
"outputs": [],
|
||||||
{
|
|
||||||
"name": "stdout",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"Requirement already satisfied: llama-stack in /usr/local/lib/python3.10/dist-packages (0.0.61)\n",
|
|
||||||
"Requirement already satisfied: blobfile in /usr/local/lib/python3.10/dist-packages (from llama-stack) (3.0.0)\n",
|
|
||||||
"Requirement already satisfied: fire in /usr/local/lib/python3.10/dist-packages (from llama-stack) (0.7.0)\n",
|
|
||||||
"Requirement already satisfied: httpx in /usr/local/lib/python3.10/dist-packages (from llama-stack) (0.28.1)\n",
|
|
||||||
"Requirement already satisfied: huggingface-hub in /usr/local/lib/python3.10/dist-packages (from llama-stack) (0.26.5)\n",
|
|
||||||
"Requirement already satisfied: llama-models>=0.0.61 in /usr/local/lib/python3.10/dist-packages (from llama-stack) (0.0.61)\n",
|
|
||||||
"Requirement already satisfied: llama-stack-client>=0.0.61 in /usr/local/lib/python3.10/dist-packages (from llama-stack) (0.0.61)\n",
|
|
||||||
"Requirement already satisfied: prompt-toolkit in /usr/local/lib/python3.10/dist-packages (from llama-stack) (3.0.48)\n",
|
|
||||||
"Requirement already satisfied: python-dotenv in /usr/local/lib/python3.10/dist-packages (from llama-stack) (1.0.1)\n",
|
|
||||||
"Requirement already satisfied: pydantic>=2 in /usr/local/lib/python3.10/dist-packages (from llama-stack) (2.10.3)\n",
|
|
||||||
"Requirement already satisfied: requests in /usr/local/lib/python3.10/dist-packages (from llama-stack) (2.32.3)\n",
|
|
||||||
"Requirement already satisfied: rich in /usr/local/lib/python3.10/dist-packages (from llama-stack) (13.9.4)\n",
|
|
||||||
"Requirement already satisfied: setuptools in /usr/local/lib/python3.10/dist-packages (from llama-stack) (75.1.0)\n",
|
|
||||||
"Requirement already satisfied: termcolor in /usr/local/lib/python3.10/dist-packages (from llama-stack) (2.5.0)\n",
|
|
||||||
"Requirement already satisfied: PyYAML in /usr/local/lib/python3.10/dist-packages (from llama-models>=0.0.61->llama-stack) (6.0.2)\n",
|
|
||||||
"Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from llama-models>=0.0.61->llama-stack) (3.1.4)\n",
|
|
||||||
"Requirement already satisfied: tiktoken in /usr/local/lib/python3.10/dist-packages (from llama-models>=0.0.61->llama-stack) (0.8.0)\n",
|
|
||||||
"Requirement already satisfied: Pillow in /usr/local/lib/python3.10/dist-packages (from llama-models>=0.0.61->llama-stack) (10.4.0)\n",
|
|
||||||
"Requirement already satisfied: anyio<5,>=3.5.0 in /usr/local/lib/python3.10/dist-packages (from llama-stack-client>=0.0.61->llama-stack) (3.7.1)\n",
|
|
||||||
"Requirement already satisfied: click in /usr/local/lib/python3.10/dist-packages (from llama-stack-client>=0.0.61->llama-stack) (8.1.7)\n",
|
|
||||||
"Requirement already satisfied: distro<2,>=1.7.0 in /usr/local/lib/python3.10/dist-packages (from llama-stack-client>=0.0.61->llama-stack) (1.9.0)\n",
|
|
||||||
"Requirement already satisfied: pandas in /usr/local/lib/python3.10/dist-packages (from llama-stack-client>=0.0.61->llama-stack) (2.2.2)\n",
|
|
||||||
"Requirement already satisfied: pyaml in /usr/local/lib/python3.10/dist-packages (from llama-stack-client>=0.0.61->llama-stack) (24.12.1)\n",
|
|
||||||
"Requirement already satisfied: sniffio in /usr/local/lib/python3.10/dist-packages (from llama-stack-client>=0.0.61->llama-stack) (1.3.1)\n",
|
|
||||||
"Requirement already satisfied: tqdm in /usr/local/lib/python3.10/dist-packages (from llama-stack-client>=0.0.61->llama-stack) (4.66.6)\n",
|
|
||||||
"Requirement already satisfied: typing-extensions<5,>=4.7 in /usr/local/lib/python3.10/dist-packages (from llama-stack-client>=0.0.61->llama-stack) (4.12.2)\n",
|
|
||||||
"Requirement already satisfied: certifi in /usr/local/lib/python3.10/dist-packages (from httpx->llama-stack) (2024.8.30)\n",
|
|
||||||
"Requirement already satisfied: httpcore==1.* in /usr/local/lib/python3.10/dist-packages (from httpx->llama-stack) (1.0.7)\n",
|
|
||||||
"Requirement already satisfied: idna in /usr/local/lib/python3.10/dist-packages (from httpx->llama-stack) (3.10)\n",
|
|
||||||
"Requirement already satisfied: h11<0.15,>=0.13 in /usr/local/lib/python3.10/dist-packages (from httpcore==1.*->httpx->llama-stack) (0.14.0)\n",
|
|
||||||
"Requirement already satisfied: annotated-types>=0.6.0 in /usr/local/lib/python3.10/dist-packages (from pydantic>=2->llama-stack) (0.7.0)\n",
|
|
||||||
"Requirement already satisfied: pydantic-core==2.27.1 in /usr/local/lib/python3.10/dist-packages (from pydantic>=2->llama-stack) (2.27.1)\n",
|
|
||||||
"Requirement already satisfied: pycryptodomex>=3.8 in /usr/local/lib/python3.10/dist-packages (from blobfile->llama-stack) (3.21.0)\n",
|
|
||||||
"Requirement already satisfied: urllib3<3,>=1.25.3 in /usr/local/lib/python3.10/dist-packages (from blobfile->llama-stack) (2.2.3)\n",
|
|
||||||
"Requirement already satisfied: lxml>=4.9 in /usr/local/lib/python3.10/dist-packages (from blobfile->llama-stack) (5.3.0)\n",
|
|
||||||
"Requirement already satisfied: filelock>=3.0 in /usr/local/lib/python3.10/dist-packages (from blobfile->llama-stack) (3.16.1)\n",
|
|
||||||
"Requirement already satisfied: fsspec>=2023.5.0 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub->llama-stack) (2024.9.0)\n",
|
|
||||||
"Requirement already satisfied: packaging>=20.9 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub->llama-stack) (24.2)\n",
|
|
||||||
"Requirement already satisfied: wcwidth in /usr/local/lib/python3.10/dist-packages (from prompt-toolkit->llama-stack) (0.2.13)\n",
|
|
||||||
"Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests->llama-stack) (3.4.0)\n",
|
|
||||||
"Requirement already satisfied: markdown-it-py>=2.2.0 in /usr/local/lib/python3.10/dist-packages (from rich->llama-stack) (3.0.0)\n",
|
|
||||||
"Requirement already satisfied: pygments<3.0.0,>=2.13.0 in /usr/local/lib/python3.10/dist-packages (from rich->llama-stack) (2.18.0)\n",
|
|
||||||
"Requirement already satisfied: exceptiongroup in /usr/local/lib/python3.10/dist-packages (from anyio<5,>=3.5.0->llama-stack-client>=0.0.61->llama-stack) (1.2.2)\n",
|
|
||||||
"Requirement already satisfied: mdurl~=0.1 in /usr/local/lib/python3.10/dist-packages (from markdown-it-py>=2.2.0->rich->llama-stack) (0.1.2)\n",
|
|
||||||
"Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->llama-models>=0.0.61->llama-stack) (3.0.2)\n",
|
|
||||||
"Requirement already satisfied: numpy>=1.22.4 in /usr/local/lib/python3.10/dist-packages (from pandas->llama-stack-client>=0.0.61->llama-stack) (1.26.4)\n",
|
|
||||||
"Requirement already satisfied: python-dateutil>=2.8.2 in /usr/local/lib/python3.10/dist-packages (from pandas->llama-stack-client>=0.0.61->llama-stack) (2.8.2)\n",
|
|
||||||
"Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.10/dist-packages (from pandas->llama-stack-client>=0.0.61->llama-stack) (2024.2)\n",
|
|
||||||
"Requirement already satisfied: tzdata>=2022.7 in /usr/local/lib/python3.10/dist-packages (from pandas->llama-stack-client>=0.0.61->llama-stack) (2024.2)\n",
|
|
||||||
"Requirement already satisfied: regex>=2022.1.18 in /usr/local/lib/python3.10/dist-packages (from tiktoken->llama-models>=0.0.61->llama-stack) (2024.9.11)\n",
|
|
||||||
"Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.10/dist-packages (from python-dateutil>=2.8.2->pandas->llama-stack-client>=0.0.61->llama-stack) (1.17.0)\n"
|
|
||||||
]
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"source": [
|
"source": [
|
||||||
"# NBVAL_SKIP\n",
|
"# NBVAL_SKIP\n",
|
||||||
"!pip install -U llama-stack"
|
"!pip install -U llama-stack"
|
||||||
|
@ -120,198 +62,10 @@
|
||||||
"id": "JQpLUSNjlGAM",
|
"id": "JQpLUSNjlGAM",
|
||||||
"outputId": "2f7fec97-5511-4cae-d51e-6d262fbca19c"
|
"outputId": "2f7fec97-5511-4cae-d51e-6d262fbca19c"
|
||||||
},
|
},
|
||||||
"outputs": [
|
"outputs": [],
|
||||||
{
|
|
||||||
"name": "stdout",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"Requirement already satisfied: llama-stack in /usr/local/lib/python3.10/dist-packages (0.0.61)\r\n",
|
|
||||||
"Requirement already satisfied: blobfile in /usr/local/lib/python3.10/dist-packages (from llama-stack) (3.0.0)\r\n",
|
|
||||||
"Requirement already satisfied: fire in /usr/local/lib/python3.10/dist-packages (from llama-stack) (0.7.0)\r\n",
|
|
||||||
"Requirement already satisfied: httpx in /usr/local/lib/python3.10/dist-packages (from llama-stack) (0.28.1)\r\n",
|
|
||||||
"Requirement already satisfied: huggingface-hub in /usr/local/lib/python3.10/dist-packages (from llama-stack) (0.26.5)\r\n",
|
|
||||||
"Requirement already satisfied: llama-models>=0.0.61 in /usr/local/lib/python3.10/dist-packages (from llama-stack) (0.0.61)\r\n",
|
|
||||||
"Requirement already satisfied: llama-stack-client>=0.0.61 in /usr/local/lib/python3.10/dist-packages (from llama-stack) (0.0.61)\r\n",
|
|
||||||
"Requirement already satisfied: prompt-toolkit in /usr/local/lib/python3.10/dist-packages (from llama-stack) (3.0.48)\r\n",
|
|
||||||
"Requirement already satisfied: python-dotenv in /usr/local/lib/python3.10/dist-packages (from llama-stack) (1.0.1)\r\n",
|
|
||||||
"Requirement already satisfied: pydantic>=2 in /usr/local/lib/python3.10/dist-packages (from llama-stack) (2.10.3)\r\n",
|
|
||||||
"Requirement already satisfied: requests in /usr/local/lib/python3.10/dist-packages (from llama-stack) (2.32.3)\r\n",
|
|
||||||
"Requirement already satisfied: rich in /usr/local/lib/python3.10/dist-packages (from llama-stack) (13.9.4)\r\n",
|
|
||||||
"Requirement already satisfied: setuptools in /usr/local/lib/python3.10/dist-packages (from llama-stack) (75.1.0)\r\n",
|
|
||||||
"Requirement already satisfied: termcolor in /usr/local/lib/python3.10/dist-packages (from llama-stack) (2.5.0)\r\n",
|
|
||||||
"Requirement already satisfied: PyYAML in /usr/local/lib/python3.10/dist-packages (from llama-models>=0.0.61->llama-stack) (6.0.2)\r\n",
|
|
||||||
"Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from llama-models>=0.0.61->llama-stack) (3.1.4)\r\n",
|
|
||||||
"Requirement already satisfied: tiktoken in /usr/local/lib/python3.10/dist-packages (from llama-models>=0.0.61->llama-stack) (0.8.0)\r\n",
|
|
||||||
"Requirement already satisfied: Pillow in /usr/local/lib/python3.10/dist-packages (from llama-models>=0.0.61->llama-stack) (10.4.0)\r\n",
|
|
||||||
"Requirement already satisfied: anyio<5,>=3.5.0 in /usr/local/lib/python3.10/dist-packages (from llama-stack-client>=0.0.61->llama-stack) (3.7.1)\r\n",
|
|
||||||
"Requirement already satisfied: click in /usr/local/lib/python3.10/dist-packages (from llama-stack-client>=0.0.61->llama-stack) (8.1.7)\r\n",
|
|
||||||
"Requirement already satisfied: distro<2,>=1.7.0 in /usr/local/lib/python3.10/dist-packages (from llama-stack-client>=0.0.61->llama-stack) (1.9.0)\r\n",
|
|
||||||
"Requirement already satisfied: pandas in /usr/local/lib/python3.10/dist-packages (from llama-stack-client>=0.0.61->llama-stack) (2.2.2)\r\n",
|
|
||||||
"Requirement already satisfied: pyaml in /usr/local/lib/python3.10/dist-packages (from llama-stack-client>=0.0.61->llama-stack) (24.12.1)\r\n",
|
|
||||||
"Requirement already satisfied: sniffio in /usr/local/lib/python3.10/dist-packages (from llama-stack-client>=0.0.61->llama-stack) (1.3.1)\r\n",
|
|
||||||
"Requirement already satisfied: tqdm in /usr/local/lib/python3.10/dist-packages (from llama-stack-client>=0.0.61->llama-stack) (4.66.6)\r\n",
|
|
||||||
"Requirement already satisfied: typing-extensions<5,>=4.7 in /usr/local/lib/python3.10/dist-packages (from llama-stack-client>=0.0.61->llama-stack) (4.12.2)\r\n",
|
|
||||||
"Requirement already satisfied: certifi in /usr/local/lib/python3.10/dist-packages (from httpx->llama-stack) (2024.8.30)\r\n",
|
|
||||||
"Requirement already satisfied: httpcore==1.* in /usr/local/lib/python3.10/dist-packages (from httpx->llama-stack) (1.0.7)\r\n",
|
|
||||||
"Requirement already satisfied: idna in /usr/local/lib/python3.10/dist-packages (from httpx->llama-stack) (3.10)\r\n",
|
|
||||||
"Requirement already satisfied: h11<0.15,>=0.13 in /usr/local/lib/python3.10/dist-packages (from httpcore==1.*->httpx->llama-stack) (0.14.0)\r\n",
|
|
||||||
"Requirement already satisfied: annotated-types>=0.6.0 in /usr/local/lib/python3.10/dist-packages (from pydantic>=2->llama-stack) (0.7.0)\r\n",
|
|
||||||
"Requirement already satisfied: pydantic-core==2.27.1 in /usr/local/lib/python3.10/dist-packages (from pydantic>=2->llama-stack) (2.27.1)\r\n",
|
|
||||||
"Requirement already satisfied: pycryptodomex>=3.8 in /usr/local/lib/python3.10/dist-packages (from blobfile->llama-stack) (3.21.0)\r\n",
|
|
||||||
"Requirement already satisfied: urllib3<3,>=1.25.3 in /usr/local/lib/python3.10/dist-packages (from blobfile->llama-stack) (2.2.3)\r\n",
|
|
||||||
"Requirement already satisfied: lxml>=4.9 in /usr/local/lib/python3.10/dist-packages (from blobfile->llama-stack) (5.3.0)\r\n",
|
|
||||||
"Requirement already satisfied: filelock>=3.0 in /usr/local/lib/python3.10/dist-packages (from blobfile->llama-stack) (3.16.1)\r\n",
|
|
||||||
"Requirement already satisfied: fsspec>=2023.5.0 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub->llama-stack) (2024.9.0)\r\n",
|
|
||||||
"Requirement already satisfied: packaging>=20.9 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub->llama-stack) (24.2)\r\n",
|
|
||||||
"Requirement already satisfied: wcwidth in /usr/local/lib/python3.10/dist-packages (from prompt-toolkit->llama-stack) (0.2.13)\r\n",
|
|
||||||
"Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests->llama-stack) (3.4.0)\r\n",
|
|
||||||
"Requirement already satisfied: markdown-it-py>=2.2.0 in /usr/local/lib/python3.10/dist-packages (from rich->llama-stack) (3.0.0)\r\n",
|
|
||||||
"Requirement already satisfied: pygments<3.0.0,>=2.13.0 in /usr/local/lib/python3.10/dist-packages (from rich->llama-stack) (2.18.0)\r\n",
|
|
||||||
"Requirement already satisfied: exceptiongroup in /usr/local/lib/python3.10/dist-packages (from anyio<5,>=3.5.0->llama-stack-client>=0.0.61->llama-stack) (1.2.2)\n",
|
|
||||||
"Requirement already satisfied: mdurl~=0.1 in /usr/local/lib/python3.10/dist-packages (from markdown-it-py>=2.2.0->rich->llama-stack) (0.1.2)\n",
|
|
||||||
"Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->llama-models>=0.0.61->llama-stack) (3.0.2)\n",
|
|
||||||
"Requirement already satisfied: numpy>=1.22.4 in /usr/local/lib/python3.10/dist-packages (from pandas->llama-stack-client>=0.0.61->llama-stack) (1.26.4)\n",
|
|
||||||
"Requirement already satisfied: python-dateutil>=2.8.2 in /usr/local/lib/python3.10/dist-packages (from pandas->llama-stack-client>=0.0.61->llama-stack) (2.8.2)\n",
|
|
||||||
"Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.10/dist-packages (from pandas->llama-stack-client>=0.0.61->llama-stack) (2024.2)\n",
|
|
||||||
"Requirement already satisfied: tzdata>=2022.7 in /usr/local/lib/python3.10/dist-packages (from pandas->llama-stack-client>=0.0.61->llama-stack) (2024.2)\n",
|
|
||||||
"Requirement already satisfied: regex>=2022.1.18 in /usr/local/lib/python3.10/dist-packages (from tiktoken->llama-models>=0.0.61->llama-stack) (2024.9.11)\n",
|
|
||||||
"Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.10/dist-packages (from python-dateutil>=2.8.2->pandas->llama-stack-client>=0.0.61->llama-stack) (1.17.0)\n",
|
|
||||||
"Installing pip dependencies\n",
|
|
||||||
"Requirement already satisfied: blobfile in /usr/local/lib/python3.10/dist-packages (3.0.0)\n",
|
|
||||||
"Requirement already satisfied: chardet in /usr/local/lib/python3.10/dist-packages (5.2.0)\n",
|
|
||||||
"Requirement already satisfied: opentelemetry-sdk in /usr/local/lib/python3.10/dist-packages (1.28.2)\n",
|
|
||||||
"Requirement already satisfied: scipy in /usr/local/lib/python3.10/dist-packages (1.13.1)\n",
|
|
||||||
"Requirement already satisfied: pandas in /usr/local/lib/python3.10/dist-packages (2.2.2)\n",
|
|
||||||
"Requirement already satisfied: autoevals in /usr/local/lib/python3.10/dist-packages (0.0.109)\n",
|
|
||||||
"Requirement already satisfied: sentencepiece in /usr/local/lib/python3.10/dist-packages (0.2.0)\n",
|
|
||||||
"Requirement already satisfied: scikit-learn in /usr/local/lib/python3.10/dist-packages (1.5.2)\n",
|
|
||||||
"Requirement already satisfied: pillow in /usr/local/lib/python3.10/dist-packages (10.4.0)\n",
|
|
||||||
"Requirement already satisfied: pypdf in /usr/local/lib/python3.10/dist-packages (5.1.0)\n",
|
|
||||||
"Requirement already satisfied: tqdm in /usr/local/lib/python3.10/dist-packages (4.66.6)\n",
|
|
||||||
"Requirement already satisfied: nltk in /usr/local/lib/python3.10/dist-packages (3.9.1)\n",
|
|
||||||
"Requirement already satisfied: aiosqlite in /usr/local/lib/python3.10/dist-packages (0.20.0)\n",
|
|
||||||
"Requirement already satisfied: psycopg2-binary in /usr/local/lib/python3.10/dist-packages (2.9.10)\n",
|
|
||||||
"Requirement already satisfied: faiss-cpu in /usr/local/lib/python3.10/dist-packages (1.9.0.post1)\n",
|
|
||||||
"Requirement already satisfied: opentelemetry-exporter-otlp-proto-http in /usr/local/lib/python3.10/dist-packages (1.28.2)\n",
|
|
||||||
"Requirement already satisfied: transformers in /usr/local/lib/python3.10/dist-packages (4.46.3)\n",
|
|
||||||
"Requirement already satisfied: numpy in /usr/local/lib/python3.10/dist-packages (1.26.4)\n",
|
|
||||||
"Requirement already satisfied: chromadb-client in /usr/local/lib/python3.10/dist-packages (0.5.23)\n",
|
|
||||||
"Requirement already satisfied: openai in /usr/local/lib/python3.10/dist-packages (1.54.5)\n",
|
|
||||||
"Requirement already satisfied: redis in /usr/local/lib/python3.10/dist-packages (5.2.1)\n",
|
|
||||||
"Requirement already satisfied: datasets in /usr/local/lib/python3.10/dist-packages (3.2.0)\n",
|
|
||||||
"Requirement already satisfied: matplotlib in /usr/local/lib/python3.10/dist-packages (3.8.0)\n",
|
|
||||||
"Requirement already satisfied: together in /usr/local/lib/python3.10/dist-packages (1.3.5)\n",
|
|
||||||
"Requirement already satisfied: fastapi in /usr/local/lib/python3.10/dist-packages (0.115.6)\n",
|
|
||||||
"Requirement already satisfied: fire in /usr/local/lib/python3.10/dist-packages (0.7.0)\n",
|
|
||||||
"Requirement already satisfied: httpx in /usr/local/lib/python3.10/dist-packages (0.28.1)\n",
|
|
||||||
"Requirement already satisfied: uvicorn in /usr/local/lib/python3.10/dist-packages (0.32.1)\n",
|
|
||||||
"Requirement already satisfied: pycryptodomex>=3.8 in /usr/local/lib/python3.10/dist-packages (from blobfile) (3.21.0)\n",
|
|
||||||
"Requirement already satisfied: urllib3<3,>=1.25.3 in /usr/local/lib/python3.10/dist-packages (from blobfile) (2.2.3)\n",
|
|
||||||
"Requirement already satisfied: lxml>=4.9 in /usr/local/lib/python3.10/dist-packages (from blobfile) (5.3.0)\n",
|
|
||||||
"Requirement already satisfied: filelock>=3.0 in /usr/local/lib/python3.10/dist-packages (from blobfile) (3.16.1)\n",
|
|
||||||
"Requirement already satisfied: opentelemetry-api==1.28.2 in /usr/local/lib/python3.10/dist-packages (from opentelemetry-sdk) (1.28.2)\n",
|
|
||||||
"Requirement already satisfied: opentelemetry-semantic-conventions==0.49b2 in /usr/local/lib/python3.10/dist-packages (from opentelemetry-sdk) (0.49b2)\n",
|
|
||||||
"Requirement already satisfied: typing-extensions>=3.7.4 in /usr/local/lib/python3.10/dist-packages (from opentelemetry-sdk) (4.12.2)\n",
|
|
||||||
"Requirement already satisfied: deprecated>=1.2.6 in /usr/local/lib/python3.10/dist-packages (from opentelemetry-api==1.28.2->opentelemetry-sdk) (1.2.15)\n",
|
|
||||||
"Requirement already satisfied: importlib-metadata<=8.5.0,>=6.0 in /usr/local/lib/python3.10/dist-packages (from opentelemetry-api==1.28.2->opentelemetry-sdk) (8.5.0)\n",
|
|
||||||
"Requirement already satisfied: python-dateutil>=2.8.2 in /usr/local/lib/python3.10/dist-packages (from pandas) (2.8.2)\n",
|
|
||||||
"Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.10/dist-packages (from pandas) (2024.2)\n",
|
|
||||||
"Requirement already satisfied: tzdata>=2022.7 in /usr/local/lib/python3.10/dist-packages (from pandas) (2024.2)\n",
|
|
||||||
"Requirement already satisfied: chevron in /usr/local/lib/python3.10/dist-packages (from autoevals) (0.14.0)\n",
|
|
||||||
"Requirement already satisfied: levenshtein in /usr/local/lib/python3.10/dist-packages (from autoevals) (0.26.1)\n",
|
|
||||||
"Requirement already satisfied: pyyaml in /usr/local/lib/python3.10/dist-packages (from autoevals) (6.0.2)\n",
|
|
||||||
"Requirement already satisfied: braintrust_core==0.0.54 in /usr/local/lib/python3.10/dist-packages (from autoevals) (0.0.54)\n",
|
|
||||||
"Requirement already satisfied: jsonschema in /usr/local/lib/python3.10/dist-packages (from autoevals) (4.23.0)\n",
|
|
||||||
"Requirement already satisfied: joblib>=1.2.0 in /usr/local/lib/python3.10/dist-packages (from scikit-learn) (1.4.2)\n",
|
|
||||||
"Requirement already satisfied: threadpoolctl>=3.1.0 in /usr/local/lib/python3.10/dist-packages (from scikit-learn) (3.5.0)\n",
|
|
||||||
"Requirement already satisfied: click in /usr/local/lib/python3.10/dist-packages (from nltk) (8.1.7)\n",
|
|
||||||
"Requirement already satisfied: regex>=2021.8.3 in /usr/local/lib/python3.10/dist-packages (from nltk) (2024.9.11)\n",
|
|
||||||
"Requirement already satisfied: packaging in /usr/local/lib/python3.10/dist-packages (from faiss-cpu) (24.2)\n",
|
|
||||||
"Requirement already satisfied: googleapis-common-protos~=1.52 in /usr/local/lib/python3.10/dist-packages (from opentelemetry-exporter-otlp-proto-http) (1.66.0)\n",
|
|
||||||
"Requirement already satisfied: opentelemetry-exporter-otlp-proto-common==1.28.2 in /usr/local/lib/python3.10/dist-packages (from opentelemetry-exporter-otlp-proto-http) (1.28.2)\n",
|
|
||||||
"Requirement already satisfied: opentelemetry-proto==1.28.2 in /usr/local/lib/python3.10/dist-packages (from opentelemetry-exporter-otlp-proto-http) (1.28.2)\n",
|
|
||||||
"Requirement already satisfied: requests~=2.7 in /usr/local/lib/python3.10/dist-packages (from opentelemetry-exporter-otlp-proto-http) (2.32.3)\n",
|
|
||||||
"Requirement already satisfied: protobuf<6.0,>=5.0 in /usr/local/lib/python3.10/dist-packages (from opentelemetry-proto==1.28.2->opentelemetry-exporter-otlp-proto-http) (5.29.1)\n",
|
|
||||||
"Requirement already satisfied: huggingface-hub<1.0,>=0.23.2 in /usr/local/lib/python3.10/dist-packages (from transformers) (0.26.5)\n",
|
|
||||||
"Requirement already satisfied: tokenizers<0.21,>=0.20 in /usr/local/lib/python3.10/dist-packages (from transformers) (0.20.3)\n",
|
|
||||||
"Requirement already satisfied: safetensors>=0.4.1 in /usr/local/lib/python3.10/dist-packages (from transformers) (0.4.5)\n",
|
|
||||||
"Requirement already satisfied: opentelemetry-exporter-otlp-proto-grpc>=1.2.0 in /usr/local/lib/python3.10/dist-packages (from chromadb-client) (1.28.2)\n",
|
|
||||||
"Requirement already satisfied: overrides>=7.3.1 in /usr/local/lib/python3.10/dist-packages (from chromadb-client) (7.7.0)\n",
|
|
||||||
"Requirement already satisfied: posthog>=2.4.0 in /usr/local/lib/python3.10/dist-packages (from chromadb-client) (3.7.4)\n",
|
|
||||||
"Requirement already satisfied: pydantic>=1.9 in /usr/local/lib/python3.10/dist-packages (from chromadb-client) (2.10.3)\n",
|
|
||||||
"Requirement already satisfied: tenacity>=8.2.3 in /usr/local/lib/python3.10/dist-packages (from chromadb-client) (9.0.0)\n",
|
|
||||||
"Requirement already satisfied: orjson>=3.9.12 in /usr/local/lib/python3.10/dist-packages (from chromadb-client) (3.10.12)\n",
|
|
||||||
"Requirement already satisfied: anyio<5,>=3.5.0 in /usr/local/lib/python3.10/dist-packages (from openai) (3.7.1)\n",
|
|
||||||
"Requirement already satisfied: distro<2,>=1.7.0 in /usr/local/lib/python3.10/dist-packages (from openai) (1.9.0)\n",
|
|
||||||
"Requirement already satisfied: jiter<1,>=0.4.0 in /usr/local/lib/python3.10/dist-packages (from openai) (0.8.2)\n",
|
|
||||||
"Requirement already satisfied: sniffio in /usr/local/lib/python3.10/dist-packages (from openai) (1.3.1)\n",
|
|
||||||
"Requirement already satisfied: async-timeout>=4.0.3 in /usr/local/lib/python3.10/dist-packages (from redis) (4.0.3)\n",
|
|
||||||
"Requirement already satisfied: pyarrow>=15.0.0 in /usr/local/lib/python3.10/dist-packages (from datasets) (17.0.0)\n",
|
|
||||||
"Requirement already satisfied: dill<0.3.9,>=0.3.0 in /usr/local/lib/python3.10/dist-packages (from datasets) (0.3.8)\n",
|
|
||||||
"Requirement already satisfied: xxhash in /usr/local/lib/python3.10/dist-packages (from datasets) (3.5.0)\n",
|
|
||||||
"Requirement already satisfied: multiprocess<0.70.17 in /usr/local/lib/python3.10/dist-packages (from datasets) (0.70.16)\n",
|
|
||||||
"Requirement already satisfied: fsspec<=2024.9.0,>=2023.1.0 in /usr/local/lib/python3.10/dist-packages (from fsspec[http]<=2024.9.0,>=2023.1.0->datasets) (2024.9.0)\n",
|
|
||||||
"Requirement already satisfied: aiohttp in /usr/local/lib/python3.10/dist-packages (from datasets) (3.11.10)\n",
|
|
||||||
"Requirement already satisfied: contourpy>=1.0.1 in /usr/local/lib/python3.10/dist-packages (from matplotlib) (1.3.1)\n",
|
|
||||||
"Requirement already satisfied: cycler>=0.10 in /usr/local/lib/python3.10/dist-packages (from matplotlib) (0.12.1)\n",
|
|
||||||
"Requirement already satisfied: fonttools>=4.22.0 in /usr/local/lib/python3.10/dist-packages (from matplotlib) (4.55.2)\n",
|
|
||||||
"Requirement already satisfied: kiwisolver>=1.0.1 in /usr/local/lib/python3.10/dist-packages (from matplotlib) (1.4.7)\n",
|
|
||||||
"Requirement already satisfied: pyparsing>=2.3.1 in /usr/local/lib/python3.10/dist-packages (from matplotlib) (3.2.0)\n",
|
|
||||||
"Requirement already satisfied: eval-type-backport<0.3.0,>=0.1.3 in /usr/local/lib/python3.10/dist-packages (from together) (0.2.0)\n",
|
|
||||||
"Requirement already satisfied: rich<14.0.0,>=13.8.1 in /usr/local/lib/python3.10/dist-packages (from together) (13.9.4)\n",
|
|
||||||
"Requirement already satisfied: tabulate<0.10.0,>=0.9.0 in /usr/local/lib/python3.10/dist-packages (from together) (0.9.0)\n",
|
|
||||||
"Requirement already satisfied: typer<0.14,>=0.9 in /usr/local/lib/python3.10/dist-packages (from together) (0.13.1)\n",
|
|
||||||
"Requirement already satisfied: starlette<0.42.0,>=0.40.0 in /usr/local/lib/python3.10/dist-packages (from fastapi) (0.41.3)\n",
|
|
||||||
"Requirement already satisfied: termcolor in /usr/local/lib/python3.10/dist-packages (from fire) (2.5.0)\n",
|
|
||||||
"Requirement already satisfied: certifi in /usr/local/lib/python3.10/dist-packages (from httpx) (2024.8.30)\n",
|
|
||||||
"Requirement already satisfied: httpcore==1.* in /usr/local/lib/python3.10/dist-packages (from httpx) (1.0.7)\n",
|
|
||||||
"Requirement already satisfied: idna in /usr/local/lib/python3.10/dist-packages (from httpx) (3.10)\n",
|
|
||||||
"Requirement already satisfied: h11<0.15,>=0.13 in /usr/local/lib/python3.10/dist-packages (from httpcore==1.*->httpx) (0.14.0)\n",
|
|
||||||
"Requirement already satisfied: aiohappyeyeballs>=2.3.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (2.4.4)\n",
|
|
||||||
"Requirement already satisfied: aiosignal>=1.1.2 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (1.3.1)\n",
|
|
||||||
"Requirement already satisfied: attrs>=17.3.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (24.2.0)\n",
|
|
||||||
"Requirement already satisfied: frozenlist>=1.1.1 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (1.5.0)\n",
|
|
||||||
"Requirement already satisfied: multidict<7.0,>=4.5 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (6.1.0)\n",
|
|
||||||
"Requirement already satisfied: propcache>=0.2.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (0.2.1)\n",
|
|
||||||
"Requirement already satisfied: yarl<2.0,>=1.17.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (1.18.3)\n",
|
|
||||||
"Requirement already satisfied: exceptiongroup in /usr/local/lib/python3.10/dist-packages (from anyio<5,>=3.5.0->openai) (1.2.2)\n",
|
|
||||||
"Requirement already satisfied: wrapt<2,>=1.10 in /usr/local/lib/python3.10/dist-packages (from deprecated>=1.2.6->opentelemetry-api==1.28.2->opentelemetry-sdk) (1.17.0)\n",
|
|
||||||
"Requirement already satisfied: grpcio<2.0.0,>=1.63.2 in /usr/local/lib/python3.10/dist-packages (from opentelemetry-exporter-otlp-proto-grpc>=1.2.0->chromadb-client) (1.68.1)\n",
|
|
||||||
"Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.10/dist-packages (from posthog>=2.4.0->chromadb-client) (1.17.0)\n",
|
|
||||||
"Requirement already satisfied: monotonic>=1.5 in /usr/local/lib/python3.10/dist-packages (from posthog>=2.4.0->chromadb-client) (1.6)\n",
|
|
||||||
"Requirement already satisfied: backoff>=1.10.0 in /usr/local/lib/python3.10/dist-packages (from posthog>=2.4.0->chromadb-client) (2.2.1)\n",
|
|
||||||
"Requirement already satisfied: annotated-types>=0.6.0 in /usr/local/lib/python3.10/dist-packages (from pydantic>=1.9->chromadb-client) (0.7.0)\n",
|
|
||||||
"Requirement already satisfied: pydantic-core==2.27.1 in /usr/local/lib/python3.10/dist-packages (from pydantic>=1.9->chromadb-client) (2.27.1)\n",
|
|
||||||
"Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests~=2.7->opentelemetry-exporter-otlp-proto-http) (3.4.0)\n",
|
|
||||||
"Requirement already satisfied: markdown-it-py>=2.2.0 in /usr/local/lib/python3.10/dist-packages (from rich<14.0.0,>=13.8.1->together) (3.0.0)\n",
|
|
||||||
"Requirement already satisfied: pygments<3.0.0,>=2.13.0 in /usr/local/lib/python3.10/dist-packages (from rich<14.0.0,>=13.8.1->together) (2.18.0)\n",
|
|
||||||
"Requirement already satisfied: shellingham>=1.3.0 in /usr/local/lib/python3.10/dist-packages (from typer<0.14,>=0.9->together) (1.5.4)\n",
|
|
||||||
"Requirement already satisfied: jsonschema-specifications>=2023.03.6 in /usr/local/lib/python3.10/dist-packages (from jsonschema->autoevals) (2024.10.1)\n",
|
|
||||||
"Requirement already satisfied: referencing>=0.28.4 in /usr/local/lib/python3.10/dist-packages (from jsonschema->autoevals) (0.35.1)\n",
|
|
||||||
"Requirement already satisfied: rpds-py>=0.7.1 in /usr/local/lib/python3.10/dist-packages (from jsonschema->autoevals) (0.22.3)\n",
|
|
||||||
"Requirement already satisfied: rapidfuzz<4.0.0,>=3.9.0 in /usr/local/lib/python3.10/dist-packages (from levenshtein->autoevals) (3.10.1)\n",
|
|
||||||
"Requirement already satisfied: zipp>=3.20 in /usr/local/lib/python3.10/dist-packages (from importlib-metadata<=8.5.0,>=6.0->opentelemetry-api==1.28.2->opentelemetry-sdk) (3.21.0)\n",
|
|
||||||
"Requirement already satisfied: mdurl~=0.1 in /usr/local/lib/python3.10/dist-packages (from markdown-it-py>=2.2.0->rich<14.0.0,>=13.8.1->together) (0.1.2)\n",
|
|
||||||
"sentence-transformers --no-deps\n",
|
|
||||||
"Requirement already satisfied: sentence-transformers in /usr/local/lib/python3.10/dist-packages (3.2.1)\n",
|
|
||||||
"torch --index-url https://download.pytorch.org/whl/cpu\n",
|
|
||||||
"Looking in indexes: https://download.pytorch.org/whl/cpu\n",
|
|
||||||
"Requirement already satisfied: torch in /usr/local/lib/python3.10/dist-packages (2.5.1+cu121)\n",
|
|
||||||
"Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from torch) (3.16.1)\n",
|
|
||||||
"Requirement already satisfied: typing-extensions>=4.8.0 in /usr/local/lib/python3.10/dist-packages (from torch) (4.12.2)\n",
|
|
||||||
"Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch) (3.4.2)\n",
|
|
||||||
"Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch) (3.1.4)\n",
|
|
||||||
"Requirement already satisfied: fsspec in /usr/local/lib/python3.10/dist-packages (from torch) (2024.9.0)\n",
|
|
||||||
"Requirement already satisfied: sympy==1.13.1 in /usr/local/lib/python3.10/dist-packages (from torch) (1.13.1)\n",
|
|
||||||
"Requirement already satisfied: mpmath<1.4,>=1.1.0 in /usr/local/lib/python3.10/dist-packages (from sympy==1.13.1->torch) (1.3.0)\n",
|
|
||||||
"Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch) (3.0.2)\n",
|
|
||||||
"\u001b[32mBuild Successful!\u001b[0m\n"
|
|
||||||
]
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"source": [
|
"source": [
|
||||||
"# NBVAL_SKIP\n",
|
"# NBVAL_SKIP\n",
|
||||||
"!llama stack build --template together --image-type venv --image-name __system__"
|
"!UV_SYSTEM_PYTHON=1 llama stack build --template together --image-type venv"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
|
|
@ -3,7 +3,7 @@ The RFC Specification (OpenAPI format) is generated from the set of API endpoint
|
||||||
Please install the following packages before running the script:
|
Please install the following packages before running the script:
|
||||||
|
|
||||||
```
|
```
|
||||||
pip install fire PyYAML llama-models
|
pip install fire PyYAML
|
||||||
```
|
```
|
||||||
|
|
||||||
Then simply run `sh run_openapi_generator.sh`
|
Then simply run `sh run_openapi_generator.sh`
|
||||||
|
|
|
@ -55,6 +55,7 @@ def main(output_dir: str):
|
||||||
a set of endpoints and their corresponding interfaces that are tailored to
|
a set of endpoints and their corresponding interfaces that are tailored to
|
||||||
best leverage Llama Models.""",
|
best leverage Llama Models.""",
|
||||||
),
|
),
|
||||||
|
include_standard_error_responses=True,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -10,6 +10,7 @@ import typing
|
||||||
from dataclasses import make_dataclass
|
from dataclasses import make_dataclass
|
||||||
from typing import Any, Dict, Set, Union
|
from typing import Any, Dict, Set, Union
|
||||||
|
|
||||||
|
from llama_stack.apis.datatypes import Error
|
||||||
from llama_stack.strong_typing.core import JsonType
|
from llama_stack.strong_typing.core import JsonType
|
||||||
from llama_stack.strong_typing.docstring import Docstring, parse_type
|
from llama_stack.strong_typing.docstring import Docstring, parse_type
|
||||||
from llama_stack.strong_typing.inspection import (
|
from llama_stack.strong_typing.inspection import (
|
||||||
|
@ -435,6 +436,75 @@ class Generator:
|
||||||
self.schema_builder = SchemaBuilder(schema_generator)
|
self.schema_builder = SchemaBuilder(schema_generator)
|
||||||
self.responses = {}
|
self.responses = {}
|
||||||
|
|
||||||
|
# Create standard error responses
|
||||||
|
self._create_standard_error_responses()
|
||||||
|
|
||||||
|
def _create_standard_error_responses(self) -> None:
|
||||||
|
"""
|
||||||
|
Creates standard error responses that can be reused across operations.
|
||||||
|
These will be added to the components.responses section of the OpenAPI document.
|
||||||
|
"""
|
||||||
|
# Get the Error schema
|
||||||
|
error_schema = self.schema_builder.classdef_to_ref(Error)
|
||||||
|
|
||||||
|
# Create standard error responses
|
||||||
|
self.responses["BadRequest400"] = Response(
|
||||||
|
description="The request was invalid or malformed",
|
||||||
|
content={
|
||||||
|
"application/json": MediaType(
|
||||||
|
schema=error_schema,
|
||||||
|
example={
|
||||||
|
"status": 400,
|
||||||
|
"title": "Bad Request",
|
||||||
|
"detail": "The request was invalid or malformed",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
self.responses["TooManyRequests429"] = Response(
|
||||||
|
description="The client has sent too many requests in a given amount of time",
|
||||||
|
content={
|
||||||
|
"application/json": MediaType(
|
||||||
|
schema=error_schema,
|
||||||
|
example={
|
||||||
|
"status": 429,
|
||||||
|
"title": "Too Many Requests",
|
||||||
|
"detail": "You have exceeded the rate limit. Please try again later.",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
self.responses["InternalServerError500"] = Response(
|
||||||
|
description="The server encountered an unexpected error",
|
||||||
|
content={
|
||||||
|
"application/json": MediaType(
|
||||||
|
schema=error_schema,
|
||||||
|
example={
|
||||||
|
"status": 500,
|
||||||
|
"title": "Internal Server Error",
|
||||||
|
"detail": "An unexpected error occurred. Our team has been notified.",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# Add a default error response for any unhandled error cases
|
||||||
|
self.responses["DefaultError"] = Response(
|
||||||
|
description="An unexpected error occurred",
|
||||||
|
content={
|
||||||
|
"application/json": MediaType(
|
||||||
|
schema=error_schema,
|
||||||
|
example={
|
||||||
|
"status": 0,
|
||||||
|
"title": "Error",
|
||||||
|
"detail": "An unexpected error occurred",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
def _build_type_tag(self, ref: str, schema: Schema) -> Tag:
|
def _build_type_tag(self, ref: str, schema: Schema) -> Tag:
|
||||||
# Don't include schema definition in the tag description because for one,
|
# Don't include schema definition in the tag description because for one,
|
||||||
# it is not very valuable and for another, it causes string formatting
|
# it is not very valuable and for another, it causes string formatting
|
||||||
|
@ -649,6 +719,18 @@ class Generator:
|
||||||
responses.update(response_builder.build_response(response_options))
|
responses.update(response_builder.build_response(response_options))
|
||||||
|
|
||||||
assert len(responses.keys()) > 0, f"No responses found for {op.name}"
|
assert len(responses.keys()) > 0, f"No responses found for {op.name}"
|
||||||
|
|
||||||
|
# Add standard error response references
|
||||||
|
if self.options.include_standard_error_responses:
|
||||||
|
if "400" not in responses:
|
||||||
|
responses["400"] = ResponseRef("BadRequest400")
|
||||||
|
if "429" not in responses:
|
||||||
|
responses["429"] = ResponseRef("TooManyRequests429")
|
||||||
|
if "500" not in responses:
|
||||||
|
responses["500"] = ResponseRef("InternalServerError500")
|
||||||
|
if "default" not in responses:
|
||||||
|
responses["default"] = ResponseRef("DefaultError")
|
||||||
|
|
||||||
if op.event_type is not None:
|
if op.event_type is not None:
|
||||||
builder = ContentBuilder(self.schema_builder)
|
builder = ContentBuilder(self.schema_builder)
|
||||||
callbacks = {
|
callbacks = {
|
||||||
|
|
|
@ -35,6 +35,7 @@ class Options:
|
||||||
:param error_wrapper: True if errors are encapsulated in an error object wrapper.
|
:param error_wrapper: True if errors are encapsulated in an error object wrapper.
|
||||||
:param property_description_fun: Custom transformation function to apply to class property documentation strings.
|
:param property_description_fun: Custom transformation function to apply to class property documentation strings.
|
||||||
:param captions: User-defined captions for sections such as "Operations" or "Types", and (if applicable) groups of extra types.
|
:param captions: User-defined captions for sections such as "Operations" or "Types", and (if applicable) groups of extra types.
|
||||||
|
:param include_standard_error_responses: Whether to include standard error responses (400, 429, 500, 503) in all operations.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
server: Server
|
server: Server
|
||||||
|
@ -52,6 +53,7 @@ class Options:
|
||||||
error_wrapper: bool = False
|
error_wrapper: bool = False
|
||||||
property_description_fun: Optional[Callable[[type, str, str], str]] = None
|
property_description_fun: Optional[Callable[[type, str, str], str]] = None
|
||||||
captions: Optional[Dict[str, str]] = None
|
captions: Optional[Dict[str, str]] = None
|
||||||
|
include_standard_error_responses: bool = True
|
||||||
|
|
||||||
default_captions: ClassVar[Dict[str, str]] = {
|
default_captions: ClassVar[Dict[str, str]] = {
|
||||||
"Operations": "Operations",
|
"Operations": "Operations",
|
||||||
|
|
|
@ -28,6 +28,5 @@ if [ ${#missing_packages[@]} -ne 0 ]; then
|
||||||
fi
|
fi
|
||||||
|
|
||||||
stack_dir=$(dirname $(dirname $THIS_DIR))
|
stack_dir=$(dirname $(dirname $THIS_DIR))
|
||||||
models_dir=$(dirname $stack_dir)/llama-models
|
PYTHONPATH=$PYTHONPATH:$stack_dir \
|
||||||
PYTHONPATH=$PYTHONPATH:$stack_dir:$models_dir \
|
|
||||||
python -m docs.openapi_generator.generate $(dirname $THIS_DIR)/_static
|
python -m docs.openapi_generator.generate $(dirname $THIS_DIR)/_static
|
||||||
|
|
91
docs/source/building_applications/agent.md
Normal file
91
docs/source/building_applications/agent.md
Normal file
|
@ -0,0 +1,91 @@
|
||||||
|
# Llama Stack Agent Framework
|
||||||
|
|
||||||
|
The Llama Stack agent framework is built on a modular architecture that allows for flexible and powerful AI applications. This document explains the key components and how they work together.
|
||||||
|
|
||||||
|
## Core Concepts
|
||||||
|
|
||||||
|
### 1. Agent Configuration
|
||||||
|
|
||||||
|
Agents are configured using the `AgentConfig` class, which includes:
|
||||||
|
|
||||||
|
- **Model**: The underlying LLM to power the agent
|
||||||
|
- **Instructions**: System prompt that defines the agent's behavior
|
||||||
|
- **Tools**: Capabilities the agent can use to interact with external systems
|
||||||
|
- **Safety Shields**: Guardrails to ensure responsible AI behavior
|
||||||
|
|
||||||
|
```python
|
||||||
|
from llama_stack_client.types.agent_create_params import AgentConfig
|
||||||
|
from llama_stack_client.lib.agents.agent import Agent
|
||||||
|
|
||||||
|
# Configure an agent
|
||||||
|
agent_config = AgentConfig(
|
||||||
|
model="meta-llama/Llama-3-70b-chat",
|
||||||
|
instructions="You are a helpful assistant that can use tools to answer questions.",
|
||||||
|
toolgroups=["builtin::code_interpreter", "builtin::rag/knowledge_search"],
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create the agent
|
||||||
|
agent = Agent(llama_stack_client, agent_config)
|
||||||
|
```
|
||||||
|
|
||||||
|
### 2. Sessions
|
||||||
|
|
||||||
|
Agents maintain state through sessions, which represent a conversation thread:
|
||||||
|
|
||||||
|
```python
|
||||||
|
# Create a session
|
||||||
|
session_id = agent.create_session(session_name="My conversation")
|
||||||
|
```
|
||||||
|
|
||||||
|
### 3. Turns
|
||||||
|
|
||||||
|
Each interaction with an agent is called a "turn" and consists of:
|
||||||
|
|
||||||
|
- **Input Messages**: What the user sends to the agent
|
||||||
|
- **Steps**: The agent's internal processing (inference, tool execution, etc.)
|
||||||
|
- **Output Message**: The agent's response
|
||||||
|
|
||||||
|
```python
|
||||||
|
from llama_stack_client.lib.agents.event_logger import EventLogger
|
||||||
|
|
||||||
|
# Create a turn with streaming response
|
||||||
|
turn_response = agent.create_turn(
|
||||||
|
session_id=session_id,
|
||||||
|
messages=[{"role": "user", "content": "Tell me about Llama models"}],
|
||||||
|
)
|
||||||
|
for log in EventLogger().log(turn_response):
|
||||||
|
log.print()
|
||||||
|
```
|
||||||
|
### Non-Streaming
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
```python
|
||||||
|
from rich.pretty import pprint
|
||||||
|
|
||||||
|
# Non-streaming API
|
||||||
|
response = agent.create_turn(
|
||||||
|
session_id=session_id,
|
||||||
|
messages=[{"role": "user", "content": "Tell me about Llama models"}],
|
||||||
|
stream=False,
|
||||||
|
)
|
||||||
|
print("Inputs:")
|
||||||
|
pprint(response.input_messages)
|
||||||
|
print("Output:")
|
||||||
|
pprint(response.output_message.content)
|
||||||
|
print("Steps:")
|
||||||
|
pprint(response.steps)
|
||||||
|
```
|
||||||
|
|
||||||
|
### 4. Steps
|
||||||
|
|
||||||
|
Each turn consists of multiple steps that represent the agent's thought process:
|
||||||
|
|
||||||
|
- **Inference Steps**: The agent generating text responses
|
||||||
|
- **Tool Execution Steps**: The agent using tools to gather information
|
||||||
|
- **Shield Call Steps**: Safety checks being performed
|
||||||
|
|
||||||
|
## Agent Execution Loop
|
||||||
|
|
||||||
|
|
||||||
|
Refer to the [Agent Execution Loop](agent_execution_loop) for more details on what happens within an agent turn.
|
|
@ -13,7 +13,7 @@ Each agent turn follows these key steps:
|
||||||
|
|
||||||
3. **Inference Loop**: The agent enters its main execution loop:
|
3. **Inference Loop**: The agent enters its main execution loop:
|
||||||
- The LLM receives a user prompt (with previous tool outputs)
|
- The LLM receives a user prompt (with previous tool outputs)
|
||||||
- The LLM generates a response, potentially with tool calls
|
- The LLM generates a response, potentially with [tool calls](tools)
|
||||||
- If tool calls are present:
|
- If tool calls are present:
|
||||||
- Tool inputs are safety-checked
|
- Tool inputs are safety-checked
|
||||||
- Tools are executed (e.g., web search, code execution)
|
- Tools are executed (e.g., web search, code execution)
|
||||||
|
@ -67,9 +67,17 @@ sequenceDiagram
|
||||||
Each step in this process can be monitored and controlled through configurations. Here's an example that demonstrates monitoring the agent's execution:
|
Each step in this process can be monitored and controlled through configurations. Here's an example that demonstrates monitoring the agent's execution:
|
||||||
|
|
||||||
```python
|
```python
|
||||||
|
from llama_stack_client import LlamaStackClient
|
||||||
|
from llama_stack_client.lib.agents.agent import Agent
|
||||||
from llama_stack_client.lib.agents.event_logger import EventLogger
|
from llama_stack_client.lib.agents.event_logger import EventLogger
|
||||||
|
from llama_stack_client.types.agent_create_params import AgentConfig
|
||||||
|
from rich.pretty import pprint
|
||||||
|
|
||||||
|
# Replace host and port
|
||||||
|
client = LlamaStackClient(base_url=f"http://{HOST}:{PORT}")
|
||||||
|
|
||||||
agent_config = AgentConfig(
|
agent_config = AgentConfig(
|
||||||
|
# Check with `llama-stack-client models list`
|
||||||
model="Llama3.2-3B-Instruct",
|
model="Llama3.2-3B-Instruct",
|
||||||
instructions="You are a helpful assistant",
|
instructions="You are a helpful assistant",
|
||||||
# Enable both RAG and tool usage
|
# Enable both RAG and tool usage
|
||||||
|
@ -80,7 +88,7 @@ agent_config = AgentConfig(
|
||||||
},
|
},
|
||||||
"builtin::code_interpreter",
|
"builtin::code_interpreter",
|
||||||
],
|
],
|
||||||
# Configure safety
|
# Configure safety (optional)
|
||||||
input_shields=["llama_guard"],
|
input_shields=["llama_guard"],
|
||||||
output_shields=["llama_guard"],
|
output_shields=["llama_guard"],
|
||||||
# Control the inference loop
|
# Control the inference loop
|
||||||
|
@ -97,7 +105,7 @@ session_id = agent.create_session("monitored_session")
|
||||||
# Stream the agent's execution steps
|
# Stream the agent's execution steps
|
||||||
response = agent.create_turn(
|
response = agent.create_turn(
|
||||||
messages=[{"role": "user", "content": "Analyze this code and run it"}],
|
messages=[{"role": "user", "content": "Analyze this code and run it"}],
|
||||||
attachments=[
|
documents=[
|
||||||
{
|
{
|
||||||
"content": "https://raw.githubusercontent.com/example/code.py",
|
"content": "https://raw.githubusercontent.com/example/code.py",
|
||||||
"mime_type": "text/plain",
|
"mime_type": "text/plain",
|
||||||
|
@ -108,14 +116,21 @@ response = agent.create_turn(
|
||||||
|
|
||||||
# Monitor each step of execution
|
# Monitor each step of execution
|
||||||
for log in EventLogger().log(response):
|
for log in EventLogger().log(response):
|
||||||
if log.event.step_type == "memory_retrieval":
|
log.print()
|
||||||
print("Retrieved context:", log.event.retrieved_context)
|
|
||||||
elif log.event.step_type == "inference":
|
# Using non-streaming API, the response contains input, steps, and output.
|
||||||
print("LLM output:", log.event.model_response)
|
response = agent.create_turn(
|
||||||
elif log.event.step_type == "tool_execution":
|
messages=[{"role": "user", "content": "Analyze this code and run it"}],
|
||||||
print("Tool call:", log.event.tool_call)
|
documents=[
|
||||||
print("Tool response:", log.event.tool_response)
|
{
|
||||||
elif log.event.step_type == "shield_call":
|
"content": "https://raw.githubusercontent.com/example/code.py",
|
||||||
if log.event.violation:
|
"mime_type": "text/plain",
|
||||||
print("Safety violation:", log.event.violation)
|
}
|
||||||
|
],
|
||||||
|
session_id=session_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
pprint(f"Input: {response.input_messages}")
|
||||||
|
pprint(f"Output: {response.output_message.content}")
|
||||||
|
pprint(f"Steps: {response.steps}")
|
||||||
```
|
```
|
||||||
|
|
|
@ -149,7 +149,6 @@ agent_config = {
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"tool_choice": "auto",
|
"tool_choice": "auto",
|
||||||
"tool_prompt_format": "json",
|
|
||||||
"input_shields": [],
|
"input_shields": [],
|
||||||
"output_shields": [],
|
"output_shields": [],
|
||||||
"enable_session_persistence": False,
|
"enable_session_persistence": False,
|
||||||
|
|
|
@ -8,22 +8,24 @@ The best way to get started is to look at this notebook which walks through the
|
||||||
|
|
||||||
Here are some key topics that will help you build effective agents:
|
Here are some key topics that will help you build effective agents:
|
||||||
|
|
||||||
- **[Agent Execution Loop](agent_execution_loop)**
|
- **[Agent](agent)**: Understand the components and design patterns of the Llama Stack agent framework.
|
||||||
- **[RAG](rag)**
|
- **[Agent Execution Loop](agent_execution_loop)**: Understand how agents process information, make decisions, and execute actions in a continuous loop.
|
||||||
- **[Safety](safety)**
|
- **[RAG (Retrieval-Augmented Generation)](rag)**: Learn how to enhance your agents with external knowledge through retrieval mechanisms.
|
||||||
- **[Tools](tools)**
|
- **[Tools](tools)**: Extend your agents' capabilities by integrating with external tools and APIs.
|
||||||
- **[Telemetry](telemetry)**
|
- **[Evals](evals)**: Evaluate your agents' effectiveness and identify areas for improvement.
|
||||||
- **[Evals](evals)**
|
- **[Telemetry](telemetry)**: Monitor and analyze your agents' performance and behavior.
|
||||||
|
- **[Safety](safety)**: Implement guardrails and safety measures to ensure responsible AI behavior.
|
||||||
|
|
||||||
```{toctree}
|
```{toctree}
|
||||||
:hidden:
|
:hidden:
|
||||||
:maxdepth: 1
|
:maxdepth: 1
|
||||||
|
|
||||||
|
agent
|
||||||
agent_execution_loop
|
agent_execution_loop
|
||||||
rag
|
rag
|
||||||
safety
|
|
||||||
tools
|
tools
|
||||||
telemetry
|
telemetry
|
||||||
evals
|
evals
|
||||||
|
advanced_agent_patterns
|
||||||
|
safety
|
||||||
```
|
```
|
||||||
|
|
|
@ -1,8 +1,8 @@
|
||||||
## Using "Memory" or Retrieval Augmented Generation (RAG)
|
## Using Retrieval Augmented Generation (RAG)
|
||||||
|
|
||||||
Memory enables your applications to reference and recall information from previous interactions or external documents.
|
RAG enables your applications to reference and recall information from previous interactions or external documents.
|
||||||
|
|
||||||
Llama Stack organizes the memory APIs into three layers:
|
Llama Stack organizes the APIs that enable RAG into three layers:
|
||||||
- the lowermost APIs deal with raw storage and retrieval. These include Vector IO, KeyValue IO (coming soon) and Relational IO (also coming soon.)
|
- the lowermost APIs deal with raw storage and retrieval. These include Vector IO, KeyValue IO (coming soon) and Relational IO (also coming soon.)
|
||||||
- next is the "Rag Tool", a first-class tool as part of the Tools API that allows you to ingest documents (from URLs, files, etc) with various chunking strategies and query them smartly.
|
- next is the "Rag Tool", a first-class tool as part of the Tools API that allows you to ingest documents (from URLs, files, etc) with various chunking strategies and query them smartly.
|
||||||
- finally, it all comes together with the top-level "Agents" API that allows you to create agents that can use the tools to answer questions, perform tasks, and more.
|
- finally, it all comes together with the top-level "Agents" API that allows you to create agents that can use the tools to answer questions, perform tasks, and more.
|
||||||
|
@ -86,7 +86,7 @@ from llama_stack_client.lib.agents.agent import Agent
|
||||||
|
|
||||||
# Configure agent with memory
|
# Configure agent with memory
|
||||||
agent_config = AgentConfig(
|
agent_config = AgentConfig(
|
||||||
model="meta-llama/Llama-3.2-3B-Instruct",
|
model="meta-llama/Llama-3.3-70B-Instruct",
|
||||||
instructions="You are a helpful assistant",
|
instructions="You are a helpful assistant",
|
||||||
enable_session_persistence=False,
|
enable_session_persistence=False,
|
||||||
toolgroups=[
|
toolgroups=[
|
||||||
|
@ -102,6 +102,19 @@ agent_config = AgentConfig(
|
||||||
agent = Agent(client, agent_config)
|
agent = Agent(client, agent_config)
|
||||||
session_id = agent.create_session("rag_session")
|
session_id = agent.create_session("rag_session")
|
||||||
|
|
||||||
|
|
||||||
|
# Ask questions about documents in the vector db, and the agent will query the db to answer the question.
|
||||||
|
response = agent.create_turn(
|
||||||
|
messages=[{"role": "user", "content": "How to optimize memory in PyTorch?"}],
|
||||||
|
session_id=session_id,
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
> **NOTE:** the `instructions` field in the `AgentConfig` can be used to guide the agent's behavior. It is important to experiment with different instructions to see what works best for your use case.
|
||||||
|
|
||||||
|
|
||||||
|
You can also pass documents along with the user's message and ask questions about them.
|
||||||
|
```python
|
||||||
# Initial document ingestion
|
# Initial document ingestion
|
||||||
response = agent.create_turn(
|
response = agent.create_turn(
|
||||||
messages=[
|
messages=[
|
||||||
|
|
|
@ -83,15 +83,15 @@ result = client.tool_runtime.invoke_tool(
|
||||||
)
|
)
|
||||||
```
|
```
|
||||||
|
|
||||||
#### Memory
|
#### RAG
|
||||||
|
|
||||||
The Memory tool enables retrieval of context from various types of memory banks (vector, key-value, keyword, and graph).
|
The RAG tool enables retrieval of context from various types of memory banks (vector, key-value, keyword, and graph).
|
||||||
|
|
||||||
```python
|
```python
|
||||||
# Register Memory tool group
|
# Register Memory tool group
|
||||||
client.toolgroups.register(
|
client.toolgroups.register(
|
||||||
toolgroup_id="builtin::memory",
|
toolgroup_id="builtin::rag",
|
||||||
provider_id="memory",
|
provider_id="faiss",
|
||||||
args={"max_chunks": 5, "max_tokens_in_context": 4096},
|
args={"max_chunks": 5, "max_tokens_in_context": 4096},
|
||||||
)
|
)
|
||||||
```
|
```
|
||||||
|
@ -102,7 +102,7 @@ Features:
|
||||||
- Context retrieval with token limits
|
- Context retrieval with token limits
|
||||||
|
|
||||||
|
|
||||||
> **Note:** By default, llama stack run.yaml defines toolgroups for web search, code interpreter and memory, that are provided by tavily-search, code-interpreter and memory providers.
|
> **Note:** By default, llama stack run.yaml defines toolgroups for web search, code interpreter and rag, that are provided by tavily-search, code-interpreter and rag providers.
|
||||||
|
|
||||||
## Model Context Protocol (MCP) Tools
|
## Model Context Protocol (MCP) Tools
|
||||||
|
|
||||||
|
@ -125,50 +125,43 @@ MCP tools require:
|
||||||
- Tools are discovered dynamically from the endpoint
|
- Tools are discovered dynamically from the endpoint
|
||||||
|
|
||||||
|
|
||||||
## Tools provided by the client
|
## Adding Custom Tools
|
||||||
|
|
||||||
These tools are registered along with the agent config and are specific to the agent for which they are registered. The main difference between these tools and the tools provided by the built-in providers is that the execution of these tools is handled by the client and the agent transfers the tool call to the client and waits for the result from the client.
|
When you want to use tools other than the built-in tools, you can implement a python function and decorate it with `@client_tool`.
|
||||||
|
|
||||||
|
To define a custom tool, you need to use the `@client_tool` decorator.
|
||||||
|
```python
|
||||||
|
from llama_stack_client.lib.agents.client_tool import client_tool
|
||||||
|
|
||||||
|
|
||||||
|
# Example tool definition
|
||||||
|
@client_tool
|
||||||
|
def my_tool(input: int) -> int:
|
||||||
|
"""
|
||||||
|
Runs my awesome tool.
|
||||||
|
|
||||||
|
:param input: some int parameter
|
||||||
|
"""
|
||||||
|
return input * 2
|
||||||
|
```
|
||||||
|
> **NOTE:** We employ python docstrings to describe the tool and the parameters. It is important to document the tool and the parameters so that the model can use the tool correctly. It is recommended to experiment with different docstrings to see how they affect the model's behavior.
|
||||||
|
|
||||||
|
Once defined, simply pass the tool to the agent config. `Agent` will take care of the rest (calling the model with the tool definition, executing the tool, and returning the result to the model for the next iteration).
|
||||||
```python
|
```python
|
||||||
# Example agent config with client provided tools
|
# Example agent config with client provided tools
|
||||||
config = AgentConfig(
|
client_tools = [
|
||||||
toolgroups=[
|
my_tool,
|
||||||
"builtin::websearch",
|
]
|
||||||
],
|
|
||||||
client_tools=[ToolDef(name="client_tool", description="Client provided tool")],
|
agent_config = AgentConfig(
|
||||||
|
...,
|
||||||
|
client_tools=[client_tool.get_tool_definition() for client_tool in client_tools],
|
||||||
)
|
)
|
||||||
|
agent = Agent(client, agent_config, client_tools)
|
||||||
```
|
```
|
||||||
|
|
||||||
Refer to [llama-stack-apps](https://github.com/meta-llama/llama-stack-apps/blob/main/examples/agents/e2e_loop_with_client_tools.py) for an example of how to use client provided tools.
|
Refer to [llama-stack-apps](https://github.com/meta-llama/llama-stack-apps/blob/main/examples/agents/e2e_loop_with_client_tools.py) for an example of how to use client provided tools.
|
||||||
|
|
||||||
## Tool Structure
|
|
||||||
|
|
||||||
Each tool has the following components:
|
|
||||||
|
|
||||||
- `name`: Unique identifier for the tool
|
|
||||||
- `description`: Human-readable description of the tool's functionality
|
|
||||||
- `parameters`: List of parameters the tool accepts
|
|
||||||
- `name`: Parameter name
|
|
||||||
- `parameter_type`: Data type (string, number, etc.)
|
|
||||||
- `description`: Parameter description
|
|
||||||
- `required`: Whether the parameter is required (default: true)
|
|
||||||
- `default`: Default value if any
|
|
||||||
|
|
||||||
Example tool definition:
|
|
||||||
```python
|
|
||||||
{
|
|
||||||
"name": "web_search",
|
|
||||||
"description": "Search the web for information",
|
|
||||||
"parameters": [
|
|
||||||
{
|
|
||||||
"name": "query",
|
|
||||||
"parameter_type": "string",
|
|
||||||
"description": "The query to search for",
|
|
||||||
"required": True,
|
|
||||||
}
|
|
||||||
],
|
|
||||||
}
|
|
||||||
```
|
|
||||||
|
|
||||||
## Tool Invocation
|
## Tool Invocation
|
||||||
|
|
||||||
|
|
|
@ -30,7 +30,7 @@ Build a Llama stack container
|
||||||
|
|
||||||
options:
|
options:
|
||||||
-h, --help show this help message and exit
|
-h, --help show this help message and exit
|
||||||
--config CONFIG Path to a config file to use for the build. You can find example configs in llama_stack/distribution/**/build.yaml.
|
--config CONFIG Path to a config file to use for the build. You can find example configs in llama_stack/distributions/**/build.yaml.
|
||||||
If this argument is not provided, you will be prompted to enter information interactively
|
If this argument is not provided, you will be prompted to enter information interactively
|
||||||
--template TEMPLATE Name of the example template config to use for build. You may use `llama stack build --list-templates` to check out the available templates
|
--template TEMPLATE Name of the example template config to use for build. You may use `llama stack build --list-templates` to check out the available templates
|
||||||
--list-templates Show the available templates for building a Llama Stack distribution
|
--list-templates Show the available templates for building a Llama Stack distribution
|
||||||
|
@ -106,7 +106,7 @@ It would be best to start with a template and understand the structure of the co
|
||||||
llama stack build
|
llama stack build
|
||||||
|
|
||||||
> Enter a name for your Llama Stack (e.g. my-local-stack): my-stack
|
> Enter a name for your Llama Stack (e.g. my-local-stack): my-stack
|
||||||
> Enter the image type you want your Llama Stack to be built as (container or conda): conda
|
> Enter the image type you want your Llama Stack to be built as (container or conda or venv): conda
|
||||||
|
|
||||||
Llama Stack is composed of several APIs working together. Let's select
|
Llama Stack is composed of several APIs working together. Let's select
|
||||||
the provider types (implementations) you want to use for these APIs.
|
the provider types (implementations) you want to use for these APIs.
|
||||||
|
@ -187,7 +187,7 @@ usage: llama stack run [-h] [--port PORT] [--image-name IMAGE_NAME] [--disable-i
|
||||||
[--tls-certfile TLS_CERTFILE] [--image-type {conda,container,venv}]
|
[--tls-certfile TLS_CERTFILE] [--image-type {conda,container,venv}]
|
||||||
config
|
config
|
||||||
|
|
||||||
start the server for a Llama Stack Distribution. You should have already built (or downloaded) and configured the distribution.
|
Start the server for a Llama Stack Distribution. You should have already built (or downloaded) and configured the distribution.
|
||||||
|
|
||||||
positional arguments:
|
positional arguments:
|
||||||
config Path to config file to use for the run
|
config Path to config file to use for the run
|
||||||
|
|
|
@ -27,16 +27,19 @@ The following environment variables can be configured:
|
||||||
|
|
||||||
The following models are available by default:
|
The following models are available by default:
|
||||||
|
|
||||||
- `meta-llama/Llama-3-8B-Instruct (meta/llama3-8b-instruct)`
|
- `meta/llama3-8b-instruct (aliases: meta-llama/Llama-3-8B-Instruct)`
|
||||||
- `meta-llama/Llama-3-70B-Instruct (meta/llama3-70b-instruct)`
|
- `meta/llama3-70b-instruct (aliases: meta-llama/Llama-3-70B-Instruct)`
|
||||||
- `meta-llama/Llama-3.1-8B-Instruct (meta/llama-3.1-8b-instruct)`
|
- `meta/llama-3.1-8b-instruct (aliases: meta-llama/Llama-3.1-8B-Instruct)`
|
||||||
- `meta-llama/Llama-3.1-70B-Instruct (meta/llama-3.1-70b-instruct)`
|
- `meta/llama-3.1-70b-instruct (aliases: meta-llama/Llama-3.1-70B-Instruct)`
|
||||||
- `meta-llama/Llama-3.1-405B-Instruct-FP8 (meta/llama-3.1-405b-instruct)`
|
- `meta/llama-3.1-405b-instruct (aliases: meta-llama/Llama-3.1-405B-Instruct-FP8)`
|
||||||
- `meta-llama/Llama-3.2-1B-Instruct (meta/llama-3.2-1b-instruct)`
|
- `meta/llama-3.2-1b-instruct (aliases: meta-llama/Llama-3.2-1B-Instruct)`
|
||||||
- `meta-llama/Llama-3.2-3B-Instruct (meta/llama-3.2-3b-instruct)`
|
- `meta/llama-3.2-3b-instruct (aliases: meta-llama/Llama-3.2-3B-Instruct)`
|
||||||
- `meta-llama/Llama-3.2-11B-Vision-Instruct (meta/llama-3.2-11b-vision-instruct)`
|
- `meta/llama-3.2-11b-vision-instruct (aliases: meta-llama/Llama-3.2-11B-Vision-Instruct)`
|
||||||
- `meta-llama/Llama-3.2-90B-Vision-Instruct (meta/llama-3.2-90b-vision-instruct)`
|
- `meta/llama-3.2-90b-vision-instruct (aliases: meta-llama/Llama-3.2-90B-Vision-Instruct)`
|
||||||
- `baai/bge-m3 (baai/bge-m3)`
|
- `nvidia/llama-3.2-nv-embedqa-1b-v2 `
|
||||||
|
- `nvidia/nv-embedqa-e5-v5 `
|
||||||
|
- `nvidia/nv-embedqa-mistral-7b-v2 `
|
||||||
|
- `snowflake/arctic-embed-l `
|
||||||
|
|
||||||
|
|
||||||
### Prerequisite: API Keys
|
### Prerequisite: API Keys
|
||||||
|
|
|
@ -34,9 +34,9 @@ The following environment variables can be configured:
|
||||||
|
|
||||||
The following models are available by default:
|
The following models are available by default:
|
||||||
|
|
||||||
- `meta-llama/Llama-3.1-8B-Instruct (meta.llama3-1-8b-instruct-v1:0)`
|
- `meta.llama3-1-8b-instruct-v1:0 (aliases: meta-llama/Llama-3.1-8B-Instruct)`
|
||||||
- `meta-llama/Llama-3.1-70B-Instruct (meta.llama3-1-70b-instruct-v1:0)`
|
- `meta.llama3-1-70b-instruct-v1:0 (aliases: meta-llama/Llama-3.1-70B-Instruct)`
|
||||||
- `meta-llama/Llama-3.1-405B-Instruct-FP8 (meta.llama3-1-405b-instruct-v1:0)`
|
- `meta.llama3-1-405b-instruct-v1:0 (aliases: meta-llama/Llama-3.1-405B-Instruct-FP8)`
|
||||||
|
|
||||||
|
|
||||||
### Prerequisite: API Keys
|
### Prerequisite: API Keys
|
||||||
|
|
|
@ -27,8 +27,8 @@ The following environment variables can be configured:
|
||||||
|
|
||||||
The following models are available by default:
|
The following models are available by default:
|
||||||
|
|
||||||
- `meta-llama/Llama-3.1-8B-Instruct (llama3.1-8b)`
|
- `llama3.1-8b (aliases: meta-llama/Llama-3.1-8B-Instruct)`
|
||||||
- `meta-llama/Llama-3.3-70B-Instruct (llama-3.3-70b)`
|
- `llama-3.3-70b (aliases: meta-llama/Llama-3.3-70B-Instruct)`
|
||||||
|
|
||||||
|
|
||||||
### Prerequisite: API Keys
|
### Prerequisite: API Keys
|
||||||
|
|
|
@ -37,17 +37,17 @@ The following environment variables can be configured:
|
||||||
|
|
||||||
The following models are available by default:
|
The following models are available by default:
|
||||||
|
|
||||||
- `meta-llama/Llama-3.1-8B-Instruct (accounts/fireworks/models/llama-v3p1-8b-instruct)`
|
- `accounts/fireworks/models/llama-v3p1-8b-instruct (aliases: meta-llama/Llama-3.1-8B-Instruct)`
|
||||||
- `meta-llama/Llama-3.1-70B-Instruct (accounts/fireworks/models/llama-v3p1-70b-instruct)`
|
- `accounts/fireworks/models/llama-v3p1-70b-instruct (aliases: meta-llama/Llama-3.1-70B-Instruct)`
|
||||||
- `meta-llama/Llama-3.1-405B-Instruct-FP8 (accounts/fireworks/models/llama-v3p1-405b-instruct)`
|
- `accounts/fireworks/models/llama-v3p1-405b-instruct (aliases: meta-llama/Llama-3.1-405B-Instruct-FP8)`
|
||||||
- `meta-llama/Llama-3.2-1B-Instruct (accounts/fireworks/models/llama-v3p2-1b-instruct)`
|
- `accounts/fireworks/models/llama-v3p2-1b-instruct (aliases: meta-llama/Llama-3.2-1B-Instruct)`
|
||||||
- `meta-llama/Llama-3.2-3B-Instruct (accounts/fireworks/models/llama-v3p2-3b-instruct)`
|
- `accounts/fireworks/models/llama-v3p2-3b-instruct (aliases: meta-llama/Llama-3.2-3B-Instruct)`
|
||||||
- `meta-llama/Llama-3.2-11B-Vision-Instruct (accounts/fireworks/models/llama-v3p2-11b-vision-instruct)`
|
- `accounts/fireworks/models/llama-v3p2-11b-vision-instruct (aliases: meta-llama/Llama-3.2-11B-Vision-Instruct)`
|
||||||
- `meta-llama/Llama-3.2-90B-Vision-Instruct (accounts/fireworks/models/llama-v3p2-90b-vision-instruct)`
|
- `accounts/fireworks/models/llama-v3p2-90b-vision-instruct (aliases: meta-llama/Llama-3.2-90B-Vision-Instruct)`
|
||||||
- `meta-llama/Llama-3.3-70B-Instruct (accounts/fireworks/models/llama-v3p3-70b-instruct)`
|
- `accounts/fireworks/models/llama-v3p3-70b-instruct (aliases: meta-llama/Llama-3.3-70B-Instruct)`
|
||||||
- `meta-llama/Llama-Guard-3-8B (accounts/fireworks/models/llama-guard-3-8b)`
|
- `accounts/fireworks/models/llama-guard-3-8b (aliases: meta-llama/Llama-Guard-3-8B)`
|
||||||
- `meta-llama/Llama-Guard-3-11B-Vision (accounts/fireworks/models/llama-guard-3-11b-vision)`
|
- `accounts/fireworks/models/llama-guard-3-11b-vision (aliases: meta-llama/Llama-Guard-3-11B-Vision)`
|
||||||
- `nomic-ai/nomic-embed-text-v1.5 (nomic-ai/nomic-embed-text-v1.5)`
|
- `nomic-ai/nomic-embed-text-v1.5 `
|
||||||
|
|
||||||
|
|
||||||
### Prerequisite: API Keys
|
### Prerequisite: API Keys
|
||||||
|
|
|
@ -37,11 +37,11 @@ The following environment variables can be configured:
|
||||||
|
|
||||||
The following models are available by default:
|
The following models are available by default:
|
||||||
|
|
||||||
- `meta-llama/Llama-3.1-8B-Instruct (llama3-8b-8192)`
|
- `groq/llama3-8b-8192 (aliases: meta-llama/Llama-3.1-8B-Instruct)`
|
||||||
- `meta-llama/Llama-3.1-8B-Instruct (llama-3.1-8b-instant)`
|
- `groq/llama-3.1-8b-instant `
|
||||||
- `meta-llama/Llama-3-70B-Instruct (llama3-70b-8192)`
|
- `groq/llama3-70b-8192 (aliases: meta-llama/Llama-3-70B-Instruct)`
|
||||||
- `meta-llama/Llama-3.3-70B-Instruct (llama-3.3-70b-versatile)`
|
- `groq/llama-3.3-70b-versatile (aliases: meta-llama/Llama-3.3-70B-Instruct)`
|
||||||
- `meta-llama/Llama-3.2-3B-Instruct (llama-3.2-3b-preview)`
|
- `groq/llama-3.2-3b-preview (aliases: meta-llama/Llama-3.2-3B-Instruct)`
|
||||||
|
|
||||||
|
|
||||||
### Prerequisite: API Keys
|
### Prerequisite: API Keys
|
||||||
|
|
|
@ -41,12 +41,31 @@ The following environment variables can be configured:
|
||||||
|
|
||||||
## Prerequisite: Downloading Models
|
## Prerequisite: Downloading Models
|
||||||
|
|
||||||
Please make sure you have llama model checkpoints downloaded in `~/.llama` before proceeding. See [installation guide](https://llama-stack.readthedocs.io/en/latest/references/llama_cli_reference/download_models.html) here to download the models. Run `llama model list` to see the available models to download, and `llama model download` to download the checkpoints.
|
Please use `llama model list --downloaded` to check that you have llama model checkpoints downloaded in `~/.llama` before proceeding. See [installation guide](https://llama-stack.readthedocs.io/en/latest/references/llama_cli_reference/download_models.html) here to download the models. Run `llama model list` to see the available models to download, and `llama model download` to download the checkpoints.
|
||||||
|
|
||||||
```
|
```
|
||||||
$ ls ~/.llama/checkpoints
|
$ llama model list --downloaded
|
||||||
Llama3.1-8B Llama3.2-11B-Vision-Instruct Llama3.2-1B-Instruct Llama3.2-90B-Vision-Instruct Llama-Guard-3-8B
|
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━┓
|
||||||
Llama3.1-8B-Instruct Llama3.2-1B Llama3.2-3B-Instruct Llama-Guard-3-1B Prompt-Guard-86M
|
┃ Model ┃ Size ┃ Modified Time ┃
|
||||||
|
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━┩
|
||||||
|
│ Llama3.2-1B-Instruct:int4-qlora-eo8 │ 1.53 GB │ 2025-02-26 11:22:28 │
|
||||||
|
├─────────────────────────────────────────┼──────────┼─────────────────────┤
|
||||||
|
│ Llama3.2-1B │ 2.31 GB │ 2025-02-18 21:48:52 │
|
||||||
|
├─────────────────────────────────────────┼──────────┼─────────────────────┤
|
||||||
|
│ Prompt-Guard-86M │ 0.02 GB │ 2025-02-26 11:29:28 │
|
||||||
|
├─────────────────────────────────────────┼──────────┼─────────────────────┤
|
||||||
|
│ Llama3.2-3B-Instruct:int4-spinquant-eo8 │ 3.69 GB │ 2025-02-26 11:37:41 │
|
||||||
|
├─────────────────────────────────────────┼──────────┼─────────────────────┤
|
||||||
|
│ Llama3.2-3B │ 5.99 GB │ 2025-02-18 21:51:26 │
|
||||||
|
├─────────────────────────────────────────┼──────────┼─────────────────────┤
|
||||||
|
│ Llama3.1-8B │ 14.97 GB │ 2025-02-16 10:36:37 │
|
||||||
|
├─────────────────────────────────────────┼──────────┼─────────────────────┤
|
||||||
|
│ Llama3.2-1B-Instruct:int4-spinquant-eo8 │ 1.51 GB │ 2025-02-26 11:35:02 │
|
||||||
|
├─────────────────────────────────────────┼──────────┼─────────────────────┤
|
||||||
|
│ Llama-Guard-3-1B │ 2.80 GB │ 2025-02-26 11:20:46 │
|
||||||
|
├─────────────────────────────────────────┼──────────┼─────────────────────┤
|
||||||
|
│ Llama-Guard-3-1B:int4 │ 0.43 GB │ 2025-02-26 11:33:33 │
|
||||||
|
└─────────────────────────────────────────┴──────────┴─────────────────────┘
|
||||||
```
|
```
|
||||||
|
|
||||||
## Running the Distribution
|
## Running the Distribution
|
||||||
|
|
|
@ -41,12 +41,31 @@ The following environment variables can be configured:
|
||||||
|
|
||||||
## Prerequisite: Downloading Models
|
## Prerequisite: Downloading Models
|
||||||
|
|
||||||
Please make sure you have llama model checkpoints downloaded in `~/.llama` before proceeding. See [installation guide](https://llama-stack.readthedocs.io/en/latest/references/llama_cli_reference/download_models.html) here to download the models. Run `llama model list` to see the available models to download, and `llama model download` to download the checkpoints.
|
Please use `llama model list --downloaded` to check that you have llama model checkpoints downloaded in `~/.llama` before proceeding. See [installation guide](https://llama-stack.readthedocs.io/en/latest/references/llama_cli_reference/download_models.html) here to download the models. Run `llama model list` to see the available models to download, and `llama model download` to download the checkpoints.
|
||||||
|
|
||||||
```
|
```
|
||||||
$ ls ~/.llama/checkpoints
|
$ llama model list --downloaded
|
||||||
Llama3.1-8B Llama3.2-11B-Vision-Instruct Llama3.2-1B-Instruct Llama3.2-90B-Vision-Instruct Llama-Guard-3-8B
|
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━┓
|
||||||
Llama3.1-8B-Instruct Llama3.2-1B Llama3.2-3B-Instruct Llama-Guard-3-1B Prompt-Guard-86M
|
┃ Model ┃ Size ┃ Modified Time ┃
|
||||||
|
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━┩
|
||||||
|
│ Llama3.2-1B-Instruct:int4-qlora-eo8 │ 1.53 GB │ 2025-02-26 11:22:28 │
|
||||||
|
├─────────────────────────────────────────┼──────────┼─────────────────────┤
|
||||||
|
│ Llama3.2-1B │ 2.31 GB │ 2025-02-18 21:48:52 │
|
||||||
|
├─────────────────────────────────────────┼──────────┼─────────────────────┤
|
||||||
|
│ Prompt-Guard-86M │ 0.02 GB │ 2025-02-26 11:29:28 │
|
||||||
|
├─────────────────────────────────────────┼──────────┼─────────────────────┤
|
||||||
|
│ Llama3.2-3B-Instruct:int4-spinquant-eo8 │ 3.69 GB │ 2025-02-26 11:37:41 │
|
||||||
|
├─────────────────────────────────────────┼──────────┼─────────────────────┤
|
||||||
|
│ Llama3.2-3B │ 5.99 GB │ 2025-02-18 21:51:26 │
|
||||||
|
├─────────────────────────────────────────┼──────────┼─────────────────────┤
|
||||||
|
│ Llama3.1-8B │ 14.97 GB │ 2025-02-16 10:36:37 │
|
||||||
|
├─────────────────────────────────────────┼──────────┼─────────────────────┤
|
||||||
|
│ Llama3.2-1B-Instruct:int4-spinquant-eo8 │ 1.51 GB │ 2025-02-26 11:35:02 │
|
||||||
|
├─────────────────────────────────────────┼──────────┼─────────────────────┤
|
||||||
|
│ Llama-Guard-3-1B │ 2.80 GB │ 2025-02-26 11:20:46 │
|
||||||
|
├─────────────────────────────────────────┼──────────┼─────────────────────┤
|
||||||
|
│ Llama-Guard-3-1B:int4 │ 0.43 GB │ 2025-02-26 11:33:33 │
|
||||||
|
└─────────────────────────────────────────┴──────────┴─────────────────────┘
|
||||||
```
|
```
|
||||||
|
|
||||||
## Running the Distribution
|
## Running the Distribution
|
||||||
|
|
|
@ -22,7 +22,7 @@ The `llamastack/distribution-ollama` distribution consists of the following prov
|
||||||
| safety | `inline::llama-guard` |
|
| safety | `inline::llama-guard` |
|
||||||
| scoring | `inline::basic`, `inline::llm-as-judge`, `inline::braintrust` |
|
| scoring | `inline::basic`, `inline::llm-as-judge`, `inline::braintrust` |
|
||||||
| telemetry | `inline::meta-reference` |
|
| telemetry | `inline::meta-reference` |
|
||||||
| tool_runtime | `remote::brave-search`, `remote::tavily-search`, `inline::code-interpreter`, `inline::rag-runtime` |
|
| tool_runtime | `remote::brave-search`, `remote::tavily-search`, `inline::code-interpreter`, `inline::rag-runtime`, `remote::model-context-protocol` |
|
||||||
| vector_io | `inline::sqlite-vec`, `remote::chromadb`, `remote::pgvector` |
|
| vector_io | `inline::sqlite-vec`, `remote::chromadb`, `remote::pgvector` |
|
||||||
|
|
||||||
|
|
||||||
|
@ -141,17 +141,21 @@ ollama run <model_name>
|
||||||
To make sure that the model is being served correctly, run `ollama ps` to get a list of models being served by ollama.
|
To make sure that the model is being served correctly, run `ollama ps` to get a list of models being served by ollama.
|
||||||
```
|
```
|
||||||
$ ollama ps
|
$ ollama ps
|
||||||
|
NAME ID SIZE PROCESSOR UNTIL
|
||||||
NAME ID SIZE PROCESSOR UNTIL
|
llama3.2:3b-instruct-fp16 195a8c01d91e 8.6 GB 100% GPU 9 minutes from now
|
||||||
llama3.1:8b-instruct-fp16 4aacac419454 17 GB 100% GPU 4 minutes from now
|
|
||||||
```
|
```
|
||||||
|
|
||||||
To verify that the model served by ollama is correctly connected to Llama Stack server
|
To verify that the model served by ollama is correctly connected to Llama Stack server
|
||||||
```bash
|
```bash
|
||||||
$ llama-stack-client models list
|
$ llama-stack-client models list
|
||||||
+----------------------+----------------------+---------------+-----------------------------------------------+
|
|
||||||
| identifier | llama_model | provider_id | metadata |
|
Available Models
|
||||||
+======================+======================+===============+===============================================+
|
|
||||||
| Llama3.1-8B-Instruct | Llama3.1-8B-Instruct | ollama0 | {'ollama_model': 'llama3.1:8b-instruct-fp16'} |
|
┏━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━┳━━━━━━━━━━━━━┓
|
||||||
+----------------------+----------------------+---------------+-----------------------------------------------+
|
┃ model_type ┃ identifier ┃ provider_resource_id ┃ metadata ┃ provider_id ┃
|
||||||
|
┡━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━╇━━━━━━━━━━━━━┩
|
||||||
|
│ llm │ meta-llama/Llama-3.2-3B-Instruct │ llama3.2:3b-instruct-fp16 │ │ ollama │
|
||||||
|
└──────────────┴──────────────────────────────────────┴──────────────────────────────┴───────────┴─────────────┘
|
||||||
|
|
||||||
|
Total models: 1
|
||||||
```
|
```
|
||||||
|
|
|
@ -34,15 +34,15 @@ The following environment variables can be configured:
|
||||||
|
|
||||||
The following models are available by default:
|
The following models are available by default:
|
||||||
|
|
||||||
- `meta-llama/Llama-3.1-8B-Instruct (Meta-Llama-3.1-8B-Instruct)`
|
- `Meta-Llama-3.1-8B-Instruct (aliases: meta-llama/Llama-3.1-8B-Instruct)`
|
||||||
- `meta-llama/Llama-3.1-70B-Instruct (Meta-Llama-3.1-70B-Instruct)`
|
- `Meta-Llama-3.1-70B-Instruct (aliases: meta-llama/Llama-3.1-70B-Instruct)`
|
||||||
- `meta-llama/Llama-3.1-405B-Instruct-FP8 (Meta-Llama-3.1-405B-Instruct)`
|
- `Meta-Llama-3.1-405B-Instruct (aliases: meta-llama/Llama-3.1-405B-Instruct-FP8)`
|
||||||
- `meta-llama/Llama-3.2-1B-Instruct (Meta-Llama-3.2-1B-Instruct)`
|
- `Meta-Llama-3.2-1B-Instruct (aliases: meta-llama/Llama-3.2-1B-Instruct)`
|
||||||
- `meta-llama/Llama-3.2-3B-Instruct (Meta-Llama-3.2-3B-Instruct)`
|
- `Meta-Llama-3.2-3B-Instruct (aliases: meta-llama/Llama-3.2-3B-Instruct)`
|
||||||
- `meta-llama/Llama-3.3-70B-Instruct (Meta-Llama-3.3-70B-Instruct)`
|
- `Meta-Llama-3.3-70B-Instruct (aliases: meta-llama/Llama-3.3-70B-Instruct)`
|
||||||
- `meta-llama/Llama-3.2-11B-Vision-Instruct (Llama-3.2-11B-Vision-Instruct)`
|
- `Llama-3.2-11B-Vision-Instruct (aliases: meta-llama/Llama-3.2-11B-Vision-Instruct)`
|
||||||
- `meta-llama/Llama-3.2-90B-Vision-Instruct (Llama-3.2-90B-Vision-Instruct)`
|
- `Llama-3.2-90B-Vision-Instruct (aliases: meta-llama/Llama-3.2-90B-Vision-Instruct)`
|
||||||
- `meta-llama/Llama-Guard-3-8B (Meta-Llama-Guard-3-8B)`
|
- `Meta-Llama-Guard-3-8B (aliases: meta-llama/Llama-Guard-3-8B)`
|
||||||
|
|
||||||
|
|
||||||
### Prerequisite: API Keys
|
### Prerequisite: API Keys
|
||||||
|
|
|
@ -37,17 +37,17 @@ The following environment variables can be configured:
|
||||||
|
|
||||||
The following models are available by default:
|
The following models are available by default:
|
||||||
|
|
||||||
- `meta-llama/Llama-3.1-8B-Instruct`
|
- `meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo (aliases: meta-llama/Llama-3.1-8B-Instruct)`
|
||||||
- `meta-llama/Llama-3.1-70B-Instruct`
|
- `meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo (aliases: meta-llama/Llama-3.1-70B-Instruct)`
|
||||||
- `meta-llama/Llama-3.1-405B-Instruct-FP8`
|
- `meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo (aliases: meta-llama/Llama-3.1-405B-Instruct-FP8)`
|
||||||
- `meta-llama/Llama-3.2-3B-Instruct`
|
- `meta-llama/Llama-3.2-3B-Instruct-Turbo (aliases: meta-llama/Llama-3.2-3B-Instruct)`
|
||||||
- `meta-llama/Llama-3.2-11B-Vision-Instruct`
|
- `meta-llama/Llama-3.2-11B-Vision-Instruct-Turbo (aliases: meta-llama/Llama-3.2-11B-Vision-Instruct)`
|
||||||
- `meta-llama/Llama-3.2-90B-Vision-Instruct`
|
- `meta-llama/Llama-3.2-90B-Vision-Instruct-Turbo (aliases: meta-llama/Llama-3.2-90B-Vision-Instruct)`
|
||||||
- `meta-llama/Llama-3.3-70B-Instruct`
|
- `meta-llama/Llama-3.3-70B-Instruct-Turbo (aliases: meta-llama/Llama-3.3-70B-Instruct)`
|
||||||
- `meta-llama/Llama-Guard-3-8B`
|
- `meta-llama/Meta-Llama-Guard-3-8B (aliases: meta-llama/Llama-Guard-3-8B)`
|
||||||
- `meta-llama/Llama-Guard-3-11B-Vision`
|
- `meta-llama/Llama-Guard-3-11B-Vision-Turbo (aliases: meta-llama/Llama-Guard-3-11B-Vision)`
|
||||||
- `togethercomputer/m2-bert-80M-8k-retrieval`
|
- `togethercomputer/m2-bert-80M-8k-retrieval `
|
||||||
- `togethercomputer/m2-bert-80M-32k-retrieval`
|
- `togethercomputer/m2-bert-80M-32k-retrieval `
|
||||||
|
|
||||||
|
|
||||||
### Prerequisite: API Keys
|
### Prerequisite: API Keys
|
||||||
|
|
|
@ -38,7 +38,7 @@ The API is **exactly identical** for both clients.
|
||||||
:::{dropdown} Starting up the Llama Stack server
|
:::{dropdown} Starting up the Llama Stack server
|
||||||
The Llama Stack server can be configured flexibly so you can mix-and-match various providers for its individual API components -- beyond Inference, these include Vector IO, Agents, Telemetry, Evals, Post Training, etc.
|
The Llama Stack server can be configured flexibly so you can mix-and-match various providers for its individual API components -- beyond Inference, these include Vector IO, Agents, Telemetry, Evals, Post Training, etc.
|
||||||
|
|
||||||
To get started quickly, we provide various container images for the server component that work with different inference providers out of the box. For this guide, we will use `llamastack/distribution-ollama` as the container image.
|
To get started quickly, we provide various container images for the server component that work with different inference providers out of the box. For this guide, we will use `llamastack/distribution-ollama` as the container image. If you'd like to build your own image or customize the configurations, please check out [this guide](../references/index.md).
|
||||||
|
|
||||||
Lets setup some environment variables that we will use in the rest of the guide.
|
Lets setup some environment variables that we will use in the rest of the guide.
|
||||||
```bash
|
```bash
|
||||||
|
@ -102,12 +102,18 @@ Let's use the `llama-stack-client` CLI to check the connectivity to the server.
|
||||||
$ llama-stack-client configure --endpoint http://localhost:$LLAMA_STACK_PORT
|
$ llama-stack-client configure --endpoint http://localhost:$LLAMA_STACK_PORT
|
||||||
> Enter the API key (leave empty if no key is needed):
|
> Enter the API key (leave empty if no key is needed):
|
||||||
Done! You can now use the Llama Stack Client CLI with endpoint http://localhost:8321
|
Done! You can now use the Llama Stack Client CLI with endpoint http://localhost:8321
|
||||||
|
|
||||||
$ llama-stack-client models list
|
$ llama-stack-client models list
|
||||||
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━┓
|
|
||||||
┃ identifier ┃ provider_id ┃ provider_resource_id ┃ metadata ┃
|
Available Models
|
||||||
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━┩
|
|
||||||
│ meta-llama/Llama-3.2-3B-Instruct │ ollama │ llama3.2:3b-instruct-fp16 │ │
|
┏━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━┳━━━━━━━━━━━━━┓
|
||||||
└──────────────────────────────────┴─────────────┴───────────────────────────┴──────────┘
|
┃ model_type ┃ identifier ┃ provider_resource_id ┃ metadata ┃ provider_id ┃
|
||||||
|
┡━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━╇━━━━━━━━━━━━━┩
|
||||||
|
│ llm │ meta-llama/Llama-3.2-3B-Instruct │ llama3.2:3b-instruct-fp16 │ │ ollama │
|
||||||
|
└──────────────┴──────────────────────────────────────┴──────────────────────────────┴───────────┴─────────────┘
|
||||||
|
|
||||||
|
Total models: 1
|
||||||
```
|
```
|
||||||
|
|
||||||
You can test basic Llama inference completion using the CLI too.
|
You can test basic Llama inference completion using the CLI too.
|
||||||
|
@ -255,7 +261,7 @@ rag_agent = Agent(client, agent_config)
|
||||||
session_id = rag_agent.create_session("test-session")
|
session_id = rag_agent.create_session("test-session")
|
||||||
|
|
||||||
user_prompts = [
|
user_prompts = [
|
||||||
"What are the top 5 topics that were explained? Only list succinct bullet points.",
|
"How to optimize memory usage in torchtune? use the knowledge_search tool to get information.",
|
||||||
]
|
]
|
||||||
|
|
||||||
# Run the agent loop by calling the `create_turn` method
|
# Run the agent loop by calling the `create_turn` method
|
||||||
|
|
|
@ -129,3 +129,35 @@ llama download --source huggingface --model-id Prompt-Guard-86M --ignore-pattern
|
||||||
**Important:** Set your environment variable `HF_TOKEN` or pass in `--hf-token` to the command to validate your access. You can find your token at [https://huggingface.co/settings/tokens](https://huggingface.co/settings/tokens).
|
**Important:** Set your environment variable `HF_TOKEN` or pass in `--hf-token` to the command to validate your access. You can find your token at [https://huggingface.co/settings/tokens](https://huggingface.co/settings/tokens).
|
||||||
|
|
||||||
> **Tip:** Default for `llama download` is to run with `--ignore-patterns *.safetensors` since we use the `.pth` files in the `original` folder. For Llama Guard and Prompt Guard, however, we need safetensors. Hence, please run with `--ignore-patterns original` so that safetensors are downloaded and `.pth` files are ignored.
|
> **Tip:** Default for `llama download` is to run with `--ignore-patterns *.safetensors` since we use the `.pth` files in the `original` folder. For Llama Guard and Prompt Guard, however, we need safetensors. Hence, please run with `--ignore-patterns original` so that safetensors are downloaded and `.pth` files are ignored.
|
||||||
|
|
||||||
|
## List the downloaded models
|
||||||
|
|
||||||
|
To list the downloaded models with the following command:
|
||||||
|
```
|
||||||
|
llama model list --downloaded
|
||||||
|
```
|
||||||
|
|
||||||
|
You should see a table like this:
|
||||||
|
```
|
||||||
|
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━┓
|
||||||
|
┃ Model ┃ Size ┃ Modified Time ┃
|
||||||
|
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━┩
|
||||||
|
│ Llama3.2-1B-Instruct:int4-qlora-eo8 │ 1.53 GB │ 2025-02-26 11:22:28 │
|
||||||
|
├─────────────────────────────────────────┼──────────┼─────────────────────┤
|
||||||
|
│ Llama3.2-1B │ 2.31 GB │ 2025-02-18 21:48:52 │
|
||||||
|
├─────────────────────────────────────────┼──────────┼─────────────────────┤
|
||||||
|
│ Prompt-Guard-86M │ 0.02 GB │ 2025-02-26 11:29:28 │
|
||||||
|
├─────────────────────────────────────────┼──────────┼─────────────────────┤
|
||||||
|
│ Llama3.2-3B-Instruct:int4-spinquant-eo8 │ 3.69 GB │ 2025-02-26 11:37:41 │
|
||||||
|
├─────────────────────────────────────────┼──────────┼─────────────────────┤
|
||||||
|
│ Llama3.2-3B │ 5.99 GB │ 2025-02-18 21:51:26 │
|
||||||
|
├─────────────────────────────────────────┼──────────┼─────────────────────┤
|
||||||
|
│ Llama3.1-8B │ 14.97 GB │ 2025-02-16 10:36:37 │
|
||||||
|
├─────────────────────────────────────────┼──────────┼─────────────────────┤
|
||||||
|
│ Llama3.2-1B-Instruct:int4-spinquant-eo8 │ 1.51 GB │ 2025-02-26 11:35:02 │
|
||||||
|
├─────────────────────────────────────────┼──────────┼─────────────────────┤
|
||||||
|
│ Llama-Guard-3-1B │ 2.80 GB │ 2025-02-26 11:20:46 │
|
||||||
|
├─────────────────────────────────────────┼──────────┼─────────────────────┤
|
||||||
|
│ Llama-Guard-3-1B:int4 │ 0.43 GB │ 2025-02-26 11:33:33 │
|
||||||
|
└─────────────────────────────────────────┴──────────┴─────────────────────┘
|
||||||
|
```
|
||||||
|
|
|
@ -154,6 +154,38 @@ llama download --source huggingface --model-id Prompt-Guard-86M --ignore-pattern
|
||||||
|
|
||||||
> **Tip:** Default for `llama download` is to run with `--ignore-patterns *.safetensors` since we use the `.pth` files in the `original` folder. For Llama Guard and Prompt Guard, however, we need safetensors. Hence, please run with `--ignore-patterns original` so that safetensors are downloaded and `.pth` files are ignored.
|
> **Tip:** Default for `llama download` is to run with `--ignore-patterns *.safetensors` since we use the `.pth` files in the `original` folder. For Llama Guard and Prompt Guard, however, we need safetensors. Hence, please run with `--ignore-patterns original` so that safetensors are downloaded and `.pth` files are ignored.
|
||||||
|
|
||||||
|
## List the downloaded models
|
||||||
|
|
||||||
|
To list the downloaded models with the following command:
|
||||||
|
```
|
||||||
|
llama model list --downloaded
|
||||||
|
```
|
||||||
|
|
||||||
|
You should see a table like this:
|
||||||
|
```
|
||||||
|
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━┓
|
||||||
|
┃ Model ┃ Size ┃ Modified Time ┃
|
||||||
|
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━┩
|
||||||
|
│ Llama3.2-1B-Instruct:int4-qlora-eo8 │ 1.53 GB │ 2025-02-26 11:22:28 │
|
||||||
|
├─────────────────────────────────────────┼──────────┼─────────────────────┤
|
||||||
|
│ Llama3.2-1B │ 2.31 GB │ 2025-02-18 21:48:52 │
|
||||||
|
├─────────────────────────────────────────┼──────────┼─────────────────────┤
|
||||||
|
│ Prompt-Guard-86M │ 0.02 GB │ 2025-02-26 11:29:28 │
|
||||||
|
├─────────────────────────────────────────┼──────────┼─────────────────────┤
|
||||||
|
│ Llama3.2-3B-Instruct:int4-spinquant-eo8 │ 3.69 GB │ 2025-02-26 11:37:41 │
|
||||||
|
├─────────────────────────────────────────┼──────────┼─────────────────────┤
|
||||||
|
│ Llama3.2-3B │ 5.99 GB │ 2025-02-18 21:51:26 │
|
||||||
|
├─────────────────────────────────────────┼──────────┼─────────────────────┤
|
||||||
|
│ Llama3.1-8B │ 14.97 GB │ 2025-02-16 10:36:37 │
|
||||||
|
├─────────────────────────────────────────┼──────────┼─────────────────────┤
|
||||||
|
│ Llama3.2-1B-Instruct:int4-spinquant-eo8 │ 1.51 GB │ 2025-02-26 11:35:02 │
|
||||||
|
├─────────────────────────────────────────┼──────────┼─────────────────────┤
|
||||||
|
│ Llama-Guard-3-1B │ 2.80 GB │ 2025-02-26 11:20:46 │
|
||||||
|
├─────────────────────────────────────────┼──────────┼─────────────────────┤
|
||||||
|
│ Llama-Guard-3-1B:int4 │ 0.43 GB │ 2025-02-26 11:33:33 │
|
||||||
|
└─────────────────────────────────────────┴──────────┴─────────────────────┘
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
## Understand the models
|
## Understand the models
|
||||||
The `llama model` command helps you explore the model’s interface.
|
The `llama model` command helps you explore the model’s interface.
|
||||||
|
|
|
@ -58,11 +58,15 @@ llama-stack-client providers list
|
||||||
llama-stack-client models list
|
llama-stack-client models list
|
||||||
```
|
```
|
||||||
```
|
```
|
||||||
+----------------------+----------------------+---------------+----------------------------------------------------------+
|
Available Models
|
||||||
| identifier | llama_model | provider_id | metadata |
|
|
||||||
+======================+======================+===============+==========================================================+
|
┏━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━┳━━━━━━━━━━━━━┓
|
||||||
| Llama3.1-8B-Instruct | Llama3.1-8B-Instruct | tgi0 | {'huggingface_repo': 'meta-llama/Llama-3.1-8B-Instruct'} |
|
┃ model_type ┃ identifier ┃ provider_resource_id ┃ metadata ┃ provider_id ┃
|
||||||
+----------------------+----------------------+---------------+----------------------------------------------------------+
|
┡━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━╇━━━━━━━━━━━━━┩
|
||||||
|
│ llm │ meta-llama/Llama-3.2-3B-Instruct │ llama3.2:3b-instruct-fp16 │ │ ollama │
|
||||||
|
└──────────────┴──────────────────────────────────────┴──────────────────────────────┴───────────┴─────────────┘
|
||||||
|
|
||||||
|
Total models: 1
|
||||||
```
|
```
|
||||||
|
|
||||||
### `llama-stack-client models get`
|
### `llama-stack-client models get`
|
||||||
|
|
|
@ -73,7 +73,7 @@ If you're looking for more specific topics, we have a [Zero to Hero Guide](#next
|
||||||
Open a new terminal and install `llama-stack`:
|
Open a new terminal and install `llama-stack`:
|
||||||
```bash
|
```bash
|
||||||
conda activate ollama
|
conda activate ollama
|
||||||
pip install llama-stack==0.1.0
|
pip install -U llama-stack
|
||||||
```
|
```
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
|
@ -5,6 +5,9 @@
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
from llama_stack.schema_utils import json_schema_type
|
from llama_stack.schema_utils import json_schema_type
|
||||||
|
|
||||||
|
@ -33,3 +36,20 @@ class Api(Enum):
|
||||||
|
|
||||||
# built-in API
|
# built-in API
|
||||||
inspect = "inspect"
|
inspect = "inspect"
|
||||||
|
|
||||||
|
|
||||||
|
@json_schema_type
|
||||||
|
class Error(BaseModel):
|
||||||
|
"""
|
||||||
|
Error response from the API. Roughly follows RFC 7807.
|
||||||
|
|
||||||
|
:param status: HTTP status code
|
||||||
|
:param title: Error title, a short summary of the error which is invariant for an error type
|
||||||
|
:param detail: Error detail, a longer human-readable description of the error
|
||||||
|
:param instance: (Optional) A URL which can be used to retrieve more information about the specific occurrence of the error
|
||||||
|
"""
|
||||||
|
|
||||||
|
status: int
|
||||||
|
title: str
|
||||||
|
detail: str
|
||||||
|
instance: Optional[str] = None
|
||||||
|
|
|
@ -9,6 +9,7 @@ import argparse
|
||||||
from .download import Download
|
from .download import Download
|
||||||
from .model import ModelParser
|
from .model import ModelParser
|
||||||
from .stack import StackParser
|
from .stack import StackParser
|
||||||
|
from .stack.utils import print_subcommand_description
|
||||||
from .verify_download import VerifyDownload
|
from .verify_download import VerifyDownload
|
||||||
|
|
||||||
|
|
||||||
|
@ -20,6 +21,7 @@ class LlamaCLIParser:
|
||||||
prog="llama",
|
prog="llama",
|
||||||
description="Welcome to the Llama CLI",
|
description="Welcome to the Llama CLI",
|
||||||
add_help=True,
|
add_help=True,
|
||||||
|
formatter_class=argparse.RawTextHelpFormatter,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Default command is to print help
|
# Default command is to print help
|
||||||
|
@ -33,6 +35,8 @@ class LlamaCLIParser:
|
||||||
Download.create(subparsers)
|
Download.create(subparsers)
|
||||||
VerifyDownload.create(subparsers)
|
VerifyDownload.create(subparsers)
|
||||||
|
|
||||||
|
print_subcommand_description(self.parser, subparsers)
|
||||||
|
|
||||||
def parse_args(self) -> argparse.Namespace:
|
def parse_args(self) -> argparse.Namespace:
|
||||||
return self.parser.parse_args()
|
return self.parser.parse_args()
|
||||||
|
|
||||||
|
|
|
@ -12,6 +12,7 @@ from llama_stack.cli.model.list import ModelList
|
||||||
from llama_stack.cli.model.prompt_format import ModelPromptFormat
|
from llama_stack.cli.model.prompt_format import ModelPromptFormat
|
||||||
from llama_stack.cli.model.remove import ModelRemove
|
from llama_stack.cli.model.remove import ModelRemove
|
||||||
from llama_stack.cli.model.verify_download import ModelVerifyDownload
|
from llama_stack.cli.model.verify_download import ModelVerifyDownload
|
||||||
|
from llama_stack.cli.stack.utils import print_subcommand_description
|
||||||
from llama_stack.cli.subcommand import Subcommand
|
from llama_stack.cli.subcommand import Subcommand
|
||||||
|
|
||||||
|
|
||||||
|
@ -24,6 +25,7 @@ class ModelParser(Subcommand):
|
||||||
"model",
|
"model",
|
||||||
prog="llama model",
|
prog="llama model",
|
||||||
description="Work with llama models",
|
description="Work with llama models",
|
||||||
|
formatter_class=argparse.RawTextHelpFormatter,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.parser.set_defaults(func=lambda args: self.parser.print_help())
|
self.parser.set_defaults(func=lambda args: self.parser.print_help())
|
||||||
|
@ -37,3 +39,5 @@ class ModelParser(Subcommand):
|
||||||
ModelDescribe.create(subparsers)
|
ModelDescribe.create(subparsers)
|
||||||
ModelVerifyDownload.create(subparsers)
|
ModelVerifyDownload.create(subparsers)
|
||||||
ModelRemove.create(subparsers)
|
ModelRemove.create(subparsers)
|
||||||
|
|
||||||
|
print_subcommand_description(self.parser, subparsers)
|
||||||
|
|
|
@ -7,10 +7,14 @@
|
||||||
import argparse
|
import argparse
|
||||||
import textwrap
|
import textwrap
|
||||||
from io import StringIO
|
from io import StringIO
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
from llama_stack.cli.subcommand import Subcommand
|
from llama_stack.cli.subcommand import Subcommand
|
||||||
|
from llama_stack.cli.table import print_table
|
||||||
from llama_stack.models.llama.datatypes import CoreModelId, ModelFamily, is_multimodal, model_family
|
from llama_stack.models.llama.datatypes import CoreModelId, ModelFamily, is_multimodal, model_family
|
||||||
|
|
||||||
|
ROOT_DIR = Path(__file__).parent.parent
|
||||||
|
|
||||||
|
|
||||||
class ModelPromptFormat(Subcommand):
|
class ModelPromptFormat(Subcommand):
|
||||||
"""Llama model cli for describe a model prompt format (message formats)"""
|
"""Llama model cli for describe a model prompt format (message formats)"""
|
||||||
|
@ -48,7 +52,26 @@ class ModelPromptFormat(Subcommand):
|
||||||
supported_model_ids = [
|
supported_model_ids = [
|
||||||
m for m in CoreModelId if model_family(m) in {ModelFamily.llama3_1, ModelFamily.llama3_2}
|
m for m in CoreModelId if model_family(m) in {ModelFamily.llama3_1, ModelFamily.llama3_2}
|
||||||
]
|
]
|
||||||
model_str = "\n".join([m.value for m in supported_model_ids])
|
|
||||||
|
model_list = [m.value for m in supported_model_ids]
|
||||||
|
model_str = "\n".join(model_list)
|
||||||
|
|
||||||
|
if args.list:
|
||||||
|
headers = ["Model(s)"]
|
||||||
|
rows = []
|
||||||
|
for m in model_list:
|
||||||
|
rows.append(
|
||||||
|
[
|
||||||
|
m,
|
||||||
|
]
|
||||||
|
)
|
||||||
|
print_table(
|
||||||
|
rows,
|
||||||
|
headers,
|
||||||
|
separate_rows=True,
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
try:
|
try:
|
||||||
model_id = CoreModelId(args.model_name)
|
model_id = CoreModelId(args.model_name)
|
||||||
except ValueError:
|
except ValueError:
|
||||||
|
@ -57,9 +80,9 @@ class ModelPromptFormat(Subcommand):
|
||||||
if model_id not in supported_model_ids:
|
if model_id not in supported_model_ids:
|
||||||
self.parser.error(f"{model_id} is not a valid Model. Choose one from --\n {model_str}")
|
self.parser.error(f"{model_id} is not a valid Model. Choose one from --\n {model_str}")
|
||||||
|
|
||||||
llama_3_1_file = importlib.resources.files("llama_models") / "llama3_1/prompt_format.md"
|
llama_3_1_file = ROOT_DIR / "models" / "llama" / "llama3_1" / "prompt_format.md"
|
||||||
llama_3_2_text_file = importlib.resources.files("llama_models") / "llama3_2/text_prompt_format.md"
|
llama_3_2_text_file = ROOT_DIR / "models" / "llama" / "llama3_2" / "text_prompt_format.md"
|
||||||
llama_3_2_vision_file = importlib.resources.files("llama_models") / "llama3_2/vision_prompt_format.md"
|
llama_3_2_vision_file = ROOT_DIR / "models" / "llama" / "llama3_2" / "vision_prompt_format.md"
|
||||||
if model_family(model_id) == ModelFamily.llama3_1:
|
if model_family(model_id) == ModelFamily.llama3_1:
|
||||||
with importlib.resources.as_file(llama_3_1_file) as f:
|
with importlib.resources.as_file(llama_3_1_file) as f:
|
||||||
content = f.open("r").read()
|
content = f.open("r").read()
|
||||||
|
|
|
@ -38,7 +38,7 @@ from llama_stack.distribution.distribution import get_provider_registry
|
||||||
from llama_stack.distribution.resolver import InvalidProviderError
|
from llama_stack.distribution.resolver import InvalidProviderError
|
||||||
from llama_stack.distribution.utils.config_dirs import DISTRIBS_BASE_DIR
|
from llama_stack.distribution.utils.config_dirs import DISTRIBS_BASE_DIR
|
||||||
from llama_stack.distribution.utils.dynamic import instantiate_class_type
|
from llama_stack.distribution.utils.dynamic import instantiate_class_type
|
||||||
from llama_stack.distribution.utils.exec import formulate_run_args, in_notebook, run_with_pty
|
from llama_stack.distribution.utils.exec import formulate_run_args, run_with_pty
|
||||||
from llama_stack.distribution.utils.image_types import ImageType
|
from llama_stack.distribution.utils.image_types import ImageType
|
||||||
from llama_stack.providers.datatypes import Api
|
from llama_stack.providers.datatypes import Api
|
||||||
|
|
||||||
|
@ -65,8 +65,6 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
|
||||||
if args.image_type == "venv":
|
if args.image_type == "venv":
|
||||||
current_venv = os.environ.get("VIRTUAL_ENV")
|
current_venv = os.environ.get("VIRTUAL_ENV")
|
||||||
image_name = args.image_name or current_venv
|
image_name = args.image_name or current_venv
|
||||||
if not image_name and in_notebook():
|
|
||||||
image_name = "__system__"
|
|
||||||
elif args.image_type == "conda":
|
elif args.image_type == "conda":
|
||||||
current_conda_env = os.environ.get("CONDA_DEFAULT_ENV")
|
current_conda_env = os.environ.get("CONDA_DEFAULT_ENV")
|
||||||
image_name = args.image_name or current_conda_env
|
image_name = args.image_name or current_conda_env
|
||||||
|
@ -143,7 +141,7 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
|
||||||
completer=WordCompleter(available_providers),
|
completer=WordCompleter(available_providers),
|
||||||
complete_while_typing=True,
|
complete_while_typing=True,
|
||||||
validator=Validator.from_callable(
|
validator=Validator.from_callable(
|
||||||
lambda x: x in available_providers,
|
lambda x: x in available_providers, # noqa: B023 - see https://github.com/astral-sh/ruff/issues/7847
|
||||||
error_message="Invalid provider, use <TAB> to see options",
|
error_message="Invalid provider, use <TAB> to see options",
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
@ -291,6 +289,8 @@ def _run_stack_build_command_from_build_config(
|
||||||
if not image_name:
|
if not image_name:
|
||||||
raise ValueError("Please specify an image name when building a conda image")
|
raise ValueError("Please specify an image name when building a conda image")
|
||||||
elif build_config.image_type == ImageType.venv.value:
|
elif build_config.image_type == ImageType.venv.value:
|
||||||
|
if not image_name and os.environ.get("UV_SYSTEM_PYTHON"):
|
||||||
|
image_name = "__system__"
|
||||||
if not image_name:
|
if not image_name:
|
||||||
raise ValueError("Please specify an image name when building a venv image")
|
raise ValueError("Please specify an image name when building a venv image")
|
||||||
|
|
||||||
|
|
|
@ -26,7 +26,7 @@ class StackBuild(Subcommand):
|
||||||
"--config",
|
"--config",
|
||||||
type=str,
|
type=str,
|
||||||
default=None,
|
default=None,
|
||||||
help="Path to a config file to use for the build. You can find example configs in llama_stack/distribution/**/build.yaml. If this argument is not provided, you will be prompted to enter information interactively",
|
help="Path to a config file to use for the build. You can find example configs in llama_stack/distributions/**/build.yaml. If this argument is not provided, you will be prompted to enter information interactively",
|
||||||
)
|
)
|
||||||
|
|
||||||
self.parser.add_argument(
|
self.parser.add_argument(
|
||||||
|
|
|
@ -7,6 +7,7 @@
|
||||||
import argparse
|
import argparse
|
||||||
from importlib.metadata import version
|
from importlib.metadata import version
|
||||||
|
|
||||||
|
from llama_stack.cli.stack.utils import print_subcommand_description
|
||||||
from llama_stack.cli.subcommand import Subcommand
|
from llama_stack.cli.subcommand import Subcommand
|
||||||
|
|
||||||
from .build import StackBuild
|
from .build import StackBuild
|
||||||
|
@ -22,6 +23,7 @@ class StackParser(Subcommand):
|
||||||
"stack",
|
"stack",
|
||||||
prog="llama stack",
|
prog="llama stack",
|
||||||
description="Operations for the Llama Stack / Distributions",
|
description="Operations for the Llama Stack / Distributions",
|
||||||
|
formatter_class=argparse.RawTextHelpFormatter,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.parser.add_argument(
|
self.parser.add_argument(
|
||||||
|
@ -39,3 +41,5 @@ class StackParser(Subcommand):
|
||||||
StackListApis.create(subparsers)
|
StackListApis.create(subparsers)
|
||||||
StackListProviders.create(subparsers)
|
StackListProviders.create(subparsers)
|
||||||
StackRun.create(subparsers)
|
StackRun.create(subparsers)
|
||||||
|
|
||||||
|
print_subcommand_description(self.parser, subparsers)
|
||||||
|
|
14
llama_stack/cli/stack/utils.py
Normal file
14
llama_stack/cli/stack/utils.py
Normal file
|
@ -0,0 +1,14 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
|
||||||
|
def print_subcommand_description(parser, subparsers):
|
||||||
|
"""Print descriptions of subcommands."""
|
||||||
|
description_text = ""
|
||||||
|
for name, subcommand in subparsers.choices.items():
|
||||||
|
description = subcommand.description
|
||||||
|
description_text += f" {name:<21} {description}\n"
|
||||||
|
parser.epilog = description_text
|
|
@ -112,7 +112,7 @@ def test_parse_and_maybe_upgrade_config_old_format(old_config):
|
||||||
|
|
||||||
inference_providers = result.providers["inference"]
|
inference_providers = result.providers["inference"]
|
||||||
assert len(inference_providers) == 2
|
assert len(inference_providers) == 2
|
||||||
assert set(x.provider_id for x in inference_providers) == {
|
assert {x.provider_id for x in inference_providers} == {
|
||||||
"remote::ollama-00",
|
"remote::ollama-00",
|
||||||
"meta-reference-01",
|
"meta-reference-01",
|
||||||
}
|
}
|
||||||
|
|
|
@ -15,7 +15,6 @@ from termcolor import cprint
|
||||||
|
|
||||||
from llama_stack.distribution.datatypes import BuildConfig, Provider
|
from llama_stack.distribution.datatypes import BuildConfig, Provider
|
||||||
from llama_stack.distribution.distribution import get_provider_registry
|
from llama_stack.distribution.distribution import get_provider_registry
|
||||||
from llama_stack.distribution.utils.config_dirs import BUILDS_BASE_DIR
|
|
||||||
from llama_stack.distribution.utils.exec import run_command, run_with_pty
|
from llama_stack.distribution.utils.exec import run_command, run_with_pty
|
||||||
from llama_stack.distribution.utils.image_types import ImageType
|
from llama_stack.distribution.utils.image_types import ImageType
|
||||||
from llama_stack.providers.datatypes import Api
|
from llama_stack.providers.datatypes import Api
|
||||||
|
@ -103,8 +102,6 @@ def build_image(
|
||||||
template_or_config,
|
template_or_config,
|
||||||
image_name,
|
image_name,
|
||||||
container_base,
|
container_base,
|
||||||
str(build_file_path),
|
|
||||||
str(BUILDS_BASE_DIR / ImageType.container.value),
|
|
||||||
" ".join(normal_deps),
|
" ".join(normal_deps),
|
||||||
]
|
]
|
||||||
elif build_config.image_type == ImageType.conda.value:
|
elif build_config.image_type == ImageType.conda.value:
|
||||||
|
|
|
@ -6,8 +6,8 @@
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
LLAMA_MODELS_DIR=${LLAMA_MODELS_DIR:-}
|
|
||||||
LLAMA_STACK_DIR=${LLAMA_STACK_DIR:-}
|
LLAMA_STACK_DIR=${LLAMA_STACK_DIR:-}
|
||||||
|
LLAMA_STACK_CLIENT_DIR=${LLAMA_STACK_CLIENT_DIR:-}
|
||||||
TEST_PYPI_VERSION=${TEST_PYPI_VERSION:-}
|
TEST_PYPI_VERSION=${TEST_PYPI_VERSION:-}
|
||||||
# This timeout (in seconds) is necessary when installing PyTorch via uv since it's likely to time out
|
# This timeout (in seconds) is necessary when installing PyTorch via uv since it's likely to time out
|
||||||
# Reference: https://github.com/astral-sh/uv/pull/1694
|
# Reference: https://github.com/astral-sh/uv/pull/1694
|
||||||
|
@ -16,8 +16,8 @@ UV_HTTP_TIMEOUT=${UV_HTTP_TIMEOUT:-500}
|
||||||
if [ -n "$LLAMA_STACK_DIR" ]; then
|
if [ -n "$LLAMA_STACK_DIR" ]; then
|
||||||
echo "Using llama-stack-dir=$LLAMA_STACK_DIR"
|
echo "Using llama-stack-dir=$LLAMA_STACK_DIR"
|
||||||
fi
|
fi
|
||||||
if [ -n "$LLAMA_MODELS_DIR" ]; then
|
if [ -n "$LLAMA_STACK_CLIENT_DIR" ]; then
|
||||||
echo "Using llama-models-dir=$LLAMA_MODELS_DIR"
|
echo "Using llama-stack-client-dir=$LLAMA_STACK_CLIENT_DIR"
|
||||||
fi
|
fi
|
||||||
|
|
||||||
if [ "$#" -lt 3 ]; then
|
if [ "$#" -lt 3 ]; then
|
||||||
|
@ -52,7 +52,7 @@ ensure_conda_env_python310() {
|
||||||
local python_version="3.10"
|
local python_version="3.10"
|
||||||
|
|
||||||
# Check if conda command is available
|
# Check if conda command is available
|
||||||
if ! command -v conda &>/dev/null; then
|
if ! is_command_available conda; then
|
||||||
printf "${RED}Error: conda command not found. Is Conda installed and in your PATH?${NC}" >&2
|
printf "${RED}Error: conda command not found. Is Conda installed and in your PATH?${NC}" >&2
|
||||||
exit 1
|
exit 1
|
||||||
fi
|
fi
|
||||||
|
@ -87,8 +87,6 @@ ensure_conda_env_python310() {
|
||||||
# these packages are damaged in test-pypi, so install them first
|
# these packages are damaged in test-pypi, so install them first
|
||||||
uv pip install fastapi libcst
|
uv pip install fastapi libcst
|
||||||
uv pip install --extra-index-url https://test.pypi.org/simple/ \
|
uv pip install --extra-index-url https://test.pypi.org/simple/ \
|
||||||
llama-models==$TEST_PYPI_VERSION \
|
|
||||||
llama-stack-client==$TEST_PYPI_VERSION \
|
|
||||||
llama-stack==$TEST_PYPI_VERSION \
|
llama-stack==$TEST_PYPI_VERSION \
|
||||||
$pip_dependencies
|
$pip_dependencies
|
||||||
if [ -n "$special_pip_deps" ]; then
|
if [ -n "$special_pip_deps" ]; then
|
||||||
|
@ -111,22 +109,21 @@ ensure_conda_env_python310() {
|
||||||
else
|
else
|
||||||
PYPI_VERSION="${PYPI_VERSION:-}"
|
PYPI_VERSION="${PYPI_VERSION:-}"
|
||||||
if [ -n "$PYPI_VERSION" ]; then
|
if [ -n "$PYPI_VERSION" ]; then
|
||||||
SPEC_VERSION="llama-stack==${PYPI_VERSION} llama-models==${PYPI_VERSION} llama-stack-client==${PYPI_VERSION}"
|
SPEC_VERSION="llama-stack==${PYPI_VERSION}"
|
||||||
else
|
else
|
||||||
SPEC_VERSION="llama-stack"
|
SPEC_VERSION="llama-stack"
|
||||||
fi
|
fi
|
||||||
uv pip install --no-cache-dir $SPEC_VERSION
|
uv pip install --no-cache-dir $SPEC_VERSION
|
||||||
fi
|
fi
|
||||||
|
|
||||||
if [ -n "$LLAMA_MODELS_DIR" ]; then
|
if [ -n "$LLAMA_STACK_CLIENT_DIR" ]; then
|
||||||
if [ ! -d "$LLAMA_MODELS_DIR" ]; then
|
if [ ! -d "$LLAMA_STACK_CLIENT_DIR" ]; then
|
||||||
printf "${RED}Warning: LLAMA_MODELS_DIR is set but directory does not exist: $LLAMA_MODELS_DIR${NC}\n" >&2
|
printf "${RED}Warning: LLAMA_STACK_CLIENT_DIR is set but directory does not exist: $LLAMA_STACK_CLIENT_DIR${NC}\n" >&2
|
||||||
exit 1
|
exit 1
|
||||||
fi
|
fi
|
||||||
|
|
||||||
printf "Installing from LLAMA_MODELS_DIR: $LLAMA_MODELS_DIR\n"
|
printf "Installing from LLAMA_STACK_CLIENT_DIR: $LLAMA_STACK_CLIENT_DIR\n"
|
||||||
uv pip uninstall llama-models
|
uv pip install --no-cache-dir -e "$LLAMA_STACK_CLIENT_DIR"
|
||||||
uv pip install --no-cache-dir -e "$LLAMA_MODELS_DIR"
|
|
||||||
fi
|
fi
|
||||||
|
|
||||||
# Install pip dependencies
|
# Install pip dependencies
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
#!/bin/bash
|
#!/usr/bin/env bash
|
||||||
|
|
||||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
# All rights reserved.
|
# All rights reserved.
|
||||||
|
@ -6,7 +6,6 @@
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
LLAMA_MODELS_DIR=${LLAMA_MODELS_DIR:-}
|
|
||||||
LLAMA_STACK_DIR=${LLAMA_STACK_DIR:-}
|
LLAMA_STACK_DIR=${LLAMA_STACK_DIR:-}
|
||||||
LLAMA_STACK_CLIENT_DIR=${LLAMA_STACK_CLIENT_DIR:-}
|
LLAMA_STACK_CLIENT_DIR=${LLAMA_STACK_CLIENT_DIR:-}
|
||||||
|
|
||||||
|
@ -20,26 +19,27 @@ UV_HTTP_TIMEOUT=${UV_HTTP_TIMEOUT:-500}
|
||||||
# mounting is not supported by docker buildx, so we use COPY instead
|
# mounting is not supported by docker buildx, so we use COPY instead
|
||||||
USE_COPY_NOT_MOUNT=${USE_COPY_NOT_MOUNT:-}
|
USE_COPY_NOT_MOUNT=${USE_COPY_NOT_MOUNT:-}
|
||||||
|
|
||||||
if [ "$#" -lt 6 ]; then
|
if [ "$#" -lt 4 ]; then
|
||||||
# This only works for templates
|
# This only works for templates
|
||||||
echo "Usage: $0 <template_or_config> <image_name> <container_base> <build_file_path> <host_build_dir> <pip_dependencies> [<special_pip_deps>]" >&2
|
echo "Usage: $0 <template_or_config> <image_name> <container_base> <pip_dependencies> [<special_pip_deps>]" >&2
|
||||||
exit 1
|
exit 1
|
||||||
fi
|
fi
|
||||||
|
|
||||||
set -euo pipefail
|
set -euo pipefail
|
||||||
|
|
||||||
template_or_config="$1"
|
template_or_config="$1"
|
||||||
image_name="$2"
|
shift
|
||||||
container_base="$3"
|
image_name="$1"
|
||||||
build_file_path="$4"
|
shift
|
||||||
host_build_dir="$5"
|
container_base="$1"
|
||||||
pip_dependencies="$6"
|
shift
|
||||||
special_pip_deps="${7:-}"
|
pip_dependencies="$1"
|
||||||
|
shift
|
||||||
|
special_pip_deps="${1:-}"
|
||||||
|
|
||||||
|
|
||||||
# Define color codes
|
# Define color codes
|
||||||
RED='\033[0;31m'
|
RED='\033[0;31m'
|
||||||
GREEN='\033[0;32m'
|
|
||||||
NC='\033[0m' # No Color
|
NC='\033[0m' # No Color
|
||||||
|
|
||||||
CONTAINER_BINARY=${CONTAINER_BINARY:-docker}
|
CONTAINER_BINARY=${CONTAINER_BINARY:-docker}
|
||||||
|
@ -47,8 +47,10 @@ CONTAINER_OPTS=${CONTAINER_OPTS:-}
|
||||||
|
|
||||||
TEMP_DIR=$(mktemp -d)
|
TEMP_DIR=$(mktemp -d)
|
||||||
|
|
||||||
|
SCRIPT_DIR=$(dirname "$(readlink -f "$0")")
|
||||||
|
source "$SCRIPT_DIR/common.sh"
|
||||||
|
|
||||||
add_to_container() {
|
add_to_container() {
|
||||||
local input
|
|
||||||
output_file="$TEMP_DIR/Containerfile"
|
output_file="$TEMP_DIR/Containerfile"
|
||||||
if [ -t 0 ]; then
|
if [ -t 0 ]; then
|
||||||
printf '%s\n' "$1" >>"$output_file"
|
printf '%s\n' "$1" >>"$output_file"
|
||||||
|
@ -58,15 +60,21 @@ add_to_container() {
|
||||||
fi
|
fi
|
||||||
}
|
}
|
||||||
|
|
||||||
|
# Check if container command is available
|
||||||
|
if ! is_command_available $CONTAINER_BINARY; then
|
||||||
|
printf "${RED}Error: ${CONTAINER_BINARY} command not found. Is ${CONTAINER_BINARY} installed and in your PATH?${NC}" >&2
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
# Update and install UBI9 components if UBI9 base image is used
|
# Update and install UBI9 components if UBI9 base image is used
|
||||||
if [[ $container_base == *"registry.access.redhat.com/ubi9"* ]]; then
|
if [[ $container_base == *"registry.access.redhat.com/ubi9"* ]]; then
|
||||||
add_to_container << EOF
|
add_to_container << EOF
|
||||||
FROM $container_base
|
FROM $container_base
|
||||||
WORKDIR /app
|
WORKDIR /app
|
||||||
|
|
||||||
RUN microdnf -y update && microdnf install -y iputils net-tools wget \
|
RUN dnf -y update && dnf install -y iputils net-tools wget \
|
||||||
vim-minimal python3.11 python3.11-pip python3.11-wheel \
|
vim-minimal python3.11 python3.11-pip python3.11-wheel \
|
||||||
python3.11-setuptools && ln -s /bin/pip3.11 /bin/pip && ln -s /bin/python3.11 /bin/python && microdnf clean all
|
python3.11-setuptools && ln -s /bin/pip3.11 /bin/pip && ln -s /bin/python3.11 /bin/python && dnf clean all
|
||||||
|
|
||||||
ENV UV_SYSTEM_PYTHON=1
|
ENV UV_SYSTEM_PYTHON=1
|
||||||
RUN pip install uv
|
RUN pip install uv
|
||||||
|
@ -107,7 +115,6 @@ EOF
|
||||||
fi
|
fi
|
||||||
|
|
||||||
stack_mount="/app/llama-stack-source"
|
stack_mount="/app/llama-stack-source"
|
||||||
models_mount="/app/llama-models-source"
|
|
||||||
client_mount="/app/llama-stack-client-source"
|
client_mount="/app/llama-stack-client-source"
|
||||||
|
|
||||||
install_local_package() {
|
install_local_package() {
|
||||||
|
@ -131,10 +138,6 @@ EOF
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
if [ -n "$LLAMA_MODELS_DIR" ]; then
|
|
||||||
install_local_package "$LLAMA_MODELS_DIR" "$models_mount" "LLAMA_MODELS_DIR"
|
|
||||||
fi
|
|
||||||
|
|
||||||
if [ -n "$LLAMA_STACK_CLIENT_DIR" ]; then
|
if [ -n "$LLAMA_STACK_CLIENT_DIR" ]; then
|
||||||
install_local_package "$LLAMA_STACK_CLIENT_DIR" "$client_mount" "LLAMA_STACK_CLIENT_DIR"
|
install_local_package "$LLAMA_STACK_CLIENT_DIR" "$client_mount" "LLAMA_STACK_CLIENT_DIR"
|
||||||
fi
|
fi
|
||||||
|
@ -150,12 +153,12 @@ EOF
|
||||||
add_to_container << EOF
|
add_to_container << EOF
|
||||||
RUN uv pip install --no-cache --extra-index-url https://test.pypi.org/simple/ \
|
RUN uv pip install --no-cache --extra-index-url https://test.pypi.org/simple/ \
|
||||||
--index-strategy unsafe-best-match \
|
--index-strategy unsafe-best-match \
|
||||||
llama-models==$TEST_PYPI_VERSION llama-stack-client==$TEST_PYPI_VERSION llama-stack==$TEST_PYPI_VERSION
|
llama-stack==$TEST_PYPI_VERSION
|
||||||
|
|
||||||
EOF
|
EOF
|
||||||
else
|
else
|
||||||
if [ -n "$PYPI_VERSION" ]; then
|
if [ -n "$PYPI_VERSION" ]; then
|
||||||
SPEC_VERSION="llama-stack==${PYPI_VERSION} llama-models==${PYPI_VERSION} llama-stack-client==${PYPI_VERSION}"
|
SPEC_VERSION="llama-stack==${PYPI_VERSION}"
|
||||||
else
|
else
|
||||||
SPEC_VERSION="llama-stack"
|
SPEC_VERSION="llama-stack"
|
||||||
fi
|
fi
|
||||||
|
@ -165,6 +168,11 @@ EOF
|
||||||
fi
|
fi
|
||||||
fi
|
fi
|
||||||
|
|
||||||
|
# remove uv after installation
|
||||||
|
add_to_container << EOF
|
||||||
|
RUN pip uninstall -y uv
|
||||||
|
EOF
|
||||||
|
|
||||||
# if template_or_config ends with .yaml, it is not a template and we should not use the --template flag
|
# if template_or_config ends with .yaml, it is not a template and we should not use the --template flag
|
||||||
if [[ "$template_or_config" != *.yaml ]]; then
|
if [[ "$template_or_config" != *.yaml ]]; then
|
||||||
add_to_container << EOF
|
add_to_container << EOF
|
||||||
|
@ -185,26 +193,28 @@ RUN mkdir -p /.llama /.cache
|
||||||
RUN chmod -R g+rw /app /.llama /.cache
|
RUN chmod -R g+rw /app /.llama /.cache
|
||||||
EOF
|
EOF
|
||||||
|
|
||||||
printf "Containerfile created successfully in $TEMP_DIR/Containerfile\n\n"
|
printf "Containerfile created successfully in %s/Containerfile\n\n" "$TEMP_DIR"
|
||||||
cat $TEMP_DIR/Containerfile
|
cat "$TEMP_DIR"/Containerfile
|
||||||
printf "\n"
|
printf "\n"
|
||||||
|
|
||||||
mounts=""
|
# Start building the CLI arguments
|
||||||
|
CLI_ARGS=()
|
||||||
|
|
||||||
|
# Read CONTAINER_OPTS and put it in an array
|
||||||
|
read -ra CLI_ARGS <<< "$CONTAINER_OPTS"
|
||||||
|
|
||||||
if [ "$USE_COPY_NOT_MOUNT" != "true" ]; then
|
if [ "$USE_COPY_NOT_MOUNT" != "true" ]; then
|
||||||
if [ -n "$LLAMA_STACK_DIR" ]; then
|
if [ -n "$LLAMA_STACK_DIR" ]; then
|
||||||
mounts="$mounts -v $(readlink -f $LLAMA_STACK_DIR):$stack_mount"
|
CLI_ARGS+=("-v" "$(readlink -f "$LLAMA_STACK_DIR"):$stack_mount")
|
||||||
fi
|
|
||||||
if [ -n "$LLAMA_MODELS_DIR" ]; then
|
|
||||||
mounts="$mounts -v $(readlink -f $LLAMA_MODELS_DIR):$models_mount"
|
|
||||||
fi
|
fi
|
||||||
if [ -n "$LLAMA_STACK_CLIENT_DIR" ]; then
|
if [ -n "$LLAMA_STACK_CLIENT_DIR" ]; then
|
||||||
mounts="$mounts -v $(readlink -f $LLAMA_STACK_CLIENT_DIR):$client_mount"
|
CLI_ARGS+=("-v" "$(readlink -f "$LLAMA_STACK_CLIENT_DIR"):$client_mount")
|
||||||
fi
|
fi
|
||||||
fi
|
fi
|
||||||
|
|
||||||
if command -v selinuxenabled &>/dev/null && selinuxenabled; then
|
if is_command_available selinuxenabled && selinuxenabled; then
|
||||||
# Disable SELinux labels -- we don't want to relabel the llama-stack source dir
|
# Disable SELinux labels -- we don't want to relabel the llama-stack source dir
|
||||||
CONTAINER_OPTS="$CONTAINER_OPTS --security-opt label=disable"
|
CLI_ARGS+=("--security-opt" "label=disable")
|
||||||
fi
|
fi
|
||||||
|
|
||||||
# Set version tag based on PyPI version
|
# Set version tag based on PyPI version
|
||||||
|
@ -212,7 +222,7 @@ if [ -n "$PYPI_VERSION" ]; then
|
||||||
version_tag="$PYPI_VERSION"
|
version_tag="$PYPI_VERSION"
|
||||||
elif [ -n "$TEST_PYPI_VERSION" ]; then
|
elif [ -n "$TEST_PYPI_VERSION" ]; then
|
||||||
version_tag="test-$TEST_PYPI_VERSION"
|
version_tag="test-$TEST_PYPI_VERSION"
|
||||||
elif [[ -n "$LLAMA_STACK_DIR" || -n "$LLAMA_MODELS_DIR" ]]; then
|
elif [[ -n "$LLAMA_STACK_DIR" || -n "$LLAMA_STACK_CLIENT_DIR" ]]; then
|
||||||
version_tag="dev"
|
version_tag="dev"
|
||||||
else
|
else
|
||||||
URL="https://pypi.org/pypi/llama-stack/json"
|
URL="https://pypi.org/pypi/llama-stack/json"
|
||||||
|
@ -225,11 +235,11 @@ image_tag="$image_name:$version_tag"
|
||||||
# Detect platform architecture
|
# Detect platform architecture
|
||||||
ARCH=$(uname -m)
|
ARCH=$(uname -m)
|
||||||
if [ -n "$BUILD_PLATFORM" ]; then
|
if [ -n "$BUILD_PLATFORM" ]; then
|
||||||
PLATFORM="--platform $BUILD_PLATFORM"
|
CLI_ARGS+=("--platform $BUILD_PLATFORM")
|
||||||
elif [ "$ARCH" = "arm64" ] || [ "$ARCH" = "aarch64" ]; then
|
elif [ "$ARCH" = "arm64" ] || [ "$ARCH" = "aarch64" ]; then
|
||||||
PLATFORM="--platform linux/arm64"
|
CLI_ARGS+=("--platform" "linux/arm64")
|
||||||
elif [ "$ARCH" = "x86_64" ]; then
|
elif [ "$ARCH" = "x86_64" ]; then
|
||||||
PLATFORM="--platform linux/amd64"
|
CLI_ARGS+=("--platform" "linux/amd64")
|
||||||
else
|
else
|
||||||
echo "Unsupported architecture: $ARCH"
|
echo "Unsupported architecture: $ARCH"
|
||||||
exit 1
|
exit 1
|
||||||
|
@ -238,8 +248,13 @@ fi
|
||||||
echo "PWD: $(pwd)"
|
echo "PWD: $(pwd)"
|
||||||
echo "Containerfile: $TEMP_DIR/Containerfile"
|
echo "Containerfile: $TEMP_DIR/Containerfile"
|
||||||
set -x
|
set -x
|
||||||
$CONTAINER_BINARY build $CONTAINER_OPTS $PLATFORM -t $image_tag \
|
|
||||||
-f "$TEMP_DIR/Containerfile" "." $mounts --progress=plain
|
$CONTAINER_BINARY build \
|
||||||
|
"${CLI_ARGS[@]}" \
|
||||||
|
-t "$image_tag" \
|
||||||
|
-f "$TEMP_DIR/Containerfile" \
|
||||||
|
"." \
|
||||||
|
--progress=plain
|
||||||
|
|
||||||
# clean up tmp/configs
|
# clean up tmp/configs
|
||||||
set +x
|
set +x
|
||||||
|
|
|
@ -9,8 +9,8 @@
|
||||||
# TODO: combine this with build_conda_env.sh since it is almost identical
|
# TODO: combine this with build_conda_env.sh since it is almost identical
|
||||||
# the only difference is that we don't do any conda-specific setup
|
# the only difference is that we don't do any conda-specific setup
|
||||||
|
|
||||||
LLAMA_MODELS_DIR=${LLAMA_MODELS_DIR:-}
|
|
||||||
LLAMA_STACK_DIR=${LLAMA_STACK_DIR:-}
|
LLAMA_STACK_DIR=${LLAMA_STACK_DIR:-}
|
||||||
|
LLAMA_STACK_CLIENT_DIR=${LLAMA_STACK_CLIENT_DIR:-}
|
||||||
TEST_PYPI_VERSION=${TEST_PYPI_VERSION:-}
|
TEST_PYPI_VERSION=${TEST_PYPI_VERSION:-}
|
||||||
# This timeout (in seconds) is necessary when installing PyTorch via uv since it's likely to time out
|
# This timeout (in seconds) is necessary when installing PyTorch via uv since it's likely to time out
|
||||||
# Reference: https://github.com/astral-sh/uv/pull/1694
|
# Reference: https://github.com/astral-sh/uv/pull/1694
|
||||||
|
@ -21,8 +21,8 @@ VIRTUAL_ENV=${VIRTUAL_ENV:-}
|
||||||
if [ -n "$LLAMA_STACK_DIR" ]; then
|
if [ -n "$LLAMA_STACK_DIR" ]; then
|
||||||
echo "Using llama-stack-dir=$LLAMA_STACK_DIR"
|
echo "Using llama-stack-dir=$LLAMA_STACK_DIR"
|
||||||
fi
|
fi
|
||||||
if [ -n "$LLAMA_MODELS_DIR" ]; then
|
if [ -n "$LLAMA_STACK_CLIENT_DIR" ]; then
|
||||||
echo "Using llama-models-dir=$LLAMA_MODELS_DIR"
|
echo "Using llama-stack-client-dir=$LLAMA_STACK_CLIENT_DIR"
|
||||||
fi
|
fi
|
||||||
|
|
||||||
if [ "$#" -lt 2 ]; then
|
if [ "$#" -lt 2 ]; then
|
||||||
|
@ -95,7 +95,7 @@ run() {
|
||||||
# we are building a command line so word splitting is expected
|
# we are building a command line so word splitting is expected
|
||||||
uv pip install --extra-index-url https://test.pypi.org/simple/ \
|
uv pip install --extra-index-url https://test.pypi.org/simple/ \
|
||||||
--index-strategy unsafe-best-match \
|
--index-strategy unsafe-best-match \
|
||||||
llama-models=="$TEST_PYPI_VERSION" llama-stack=="$TEST_PYPI_VERSION" \
|
llama-stack=="$TEST_PYPI_VERSION" \
|
||||||
$pip_dependencies
|
$pip_dependencies
|
||||||
if [ -n "$special_pip_deps" ]; then
|
if [ -n "$special_pip_deps" ]; then
|
||||||
IFS='#' read -ra parts <<<"$special_pip_deps"
|
IFS='#' read -ra parts <<<"$special_pip_deps"
|
||||||
|
@ -120,15 +120,14 @@ run() {
|
||||||
uv pip install --no-cache-dir llama-stack
|
uv pip install --no-cache-dir llama-stack
|
||||||
fi
|
fi
|
||||||
|
|
||||||
if [ -n "$LLAMA_MODELS_DIR" ]; then
|
if [ -n "$LLAMA_STACK_CLIENT_DIR" ]; then
|
||||||
if [ ! -d "$LLAMA_MODELS_DIR" ]; then
|
if [ ! -d "$LLAMA_STACK_CLIENT_DIR" ]; then
|
||||||
printf "${RED}Warning: LLAMA_MODELS_DIR is set but directory does not exist: %s${NC}\n" "$LLAMA_MODELS_DIR" >&2
|
printf "${RED}Warning: LLAMA_STACK_CLIENT_DIR is set but directory does not exist: %s${NC}\n" "$LLAMA_STACK_CLIENT_DIR" >&2
|
||||||
exit 1
|
exit 1
|
||||||
fi
|
fi
|
||||||
|
|
||||||
printf "Installing from LLAMA_MODELS_DIR: %s\n" "$LLAMA_MODELS_DIR"
|
printf "Installing from LLAMA_STACK_CLIENT_DIR: %s\n" "$LLAMA_STACK_CLIENT_DIR"
|
||||||
uv pip uninstall llama-models
|
uv pip install --no-cache-dir -e "$LLAMA_STACK_CLIENT_DIR"
|
||||||
uv pip install --no-cache-dir -e "$LLAMA_MODELS_DIR"
|
|
||||||
fi
|
fi
|
||||||
|
|
||||||
# Install pip dependencies
|
# Install pip dependencies
|
||||||
|
|
|
@ -1,47 +0,0 @@
|
||||||
#!/bin/bash
|
|
||||||
|
|
||||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
||||||
# All rights reserved.
|
|
||||||
#
|
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
|
||||||
# the root directory of this source tree.
|
|
||||||
|
|
||||||
CONTAINER_BINARY=${CONTAINER_BINARY:-docker}
|
|
||||||
CONTAINER_OPTS=${CONTAINER_OPTS:-}
|
|
||||||
LLAMA_STACK_DIR=${LLAMA_STACK_DIR:-}
|
|
||||||
|
|
||||||
set -euo pipefail
|
|
||||||
|
|
||||||
error_handler() {
|
|
||||||
echo "Error occurred in script at line: ${1}" >&2
|
|
||||||
exit 1
|
|
||||||
}
|
|
||||||
|
|
||||||
trap 'error_handler ${LINENO}' ERR
|
|
||||||
|
|
||||||
if [ $# -lt 2 ]; then
|
|
||||||
echo "Usage: $0 <container name> <build file path>"
|
|
||||||
exit 1
|
|
||||||
fi
|
|
||||||
|
|
||||||
container_image="$1"
|
|
||||||
host_build_dir="$2"
|
|
||||||
container_build_dir="/app/builds"
|
|
||||||
|
|
||||||
if command -v selinuxenabled &> /dev/null && selinuxenabled; then
|
|
||||||
# Disable SELinux labels
|
|
||||||
CONTAINER_OPTS="$CONTAINER_OPTS --security-opt label=disable"
|
|
||||||
fi
|
|
||||||
|
|
||||||
mounts=""
|
|
||||||
if [ -n "$LLAMA_STACK_DIR" ]; then
|
|
||||||
mounts="$mounts -v $(readlink -f $LLAMA_STACK_DIR):/app/llama-stack-source"
|
|
||||||
fi
|
|
||||||
|
|
||||||
set -x
|
|
||||||
$CONTAINER_BINARY run $CONTAINER_OPTS -it \
|
|
||||||
--entrypoint "/usr/local/bin/llama" \
|
|
||||||
-v $host_build_dir:$container_build_dir \
|
|
||||||
$mounts \
|
|
||||||
$container_image \
|
|
||||||
stack configure ./llamastack-build.yaml --output-dir $container_build_dir
|
|
|
@ -13,7 +13,7 @@ from llama_stack.providers.datatypes import Api, ProviderSpec
|
||||||
|
|
||||||
|
|
||||||
def stack_apis() -> List[Api]:
|
def stack_apis() -> List[Api]:
|
||||||
return [v for v in Api]
|
return list(Api)
|
||||||
|
|
||||||
|
|
||||||
class AutoRoutedApiInfo(BaseModel):
|
class AutoRoutedApiInfo(BaseModel):
|
||||||
|
@ -55,7 +55,7 @@ def builtin_automatically_routed_apis() -> List[AutoRoutedApiInfo]:
|
||||||
|
|
||||||
|
|
||||||
def providable_apis() -> List[Api]:
|
def providable_apis() -> List[Api]:
|
||||||
routing_table_apis = set(x.routing_table_api for x in builtin_automatically_routed_apis())
|
routing_table_apis = {x.routing_table_api for x in builtin_automatically_routed_apis()}
|
||||||
return [api for api in Api if api not in routing_table_apis and api != Api.inspect]
|
return [api for api in Api if api not in routing_table_apis and api != Api.inspect]
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -104,7 +104,7 @@ def convert_to_pydantic(annotation: Any, value: Any) -> Any:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"Warning: direct client failed to convert parameter {value} into {annotation}: {e}",
|
f"Warning: direct client failed to convert parameter {value} into {annotation}: {e}",
|
||||||
)
|
)
|
||||||
return value
|
raise ValueError(f"Failed to convert parameter {value} into {annotation}: {e}") from e
|
||||||
|
|
||||||
|
|
||||||
class LlamaStackAsLibraryClient(LlamaStackClient):
|
class LlamaStackAsLibraryClient(LlamaStackClient):
|
||||||
|
@ -324,6 +324,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
|
||||||
await end_trace()
|
await end_trace()
|
||||||
|
|
||||||
json_content = json.dumps(convert_pydantic_to_json_value(result))
|
json_content = json.dumps(convert_pydantic_to_json_value(result))
|
||||||
|
|
||||||
mock_response = httpx.Response(
|
mock_response = httpx.Response(
|
||||||
status_code=httpx.codes.OK,
|
status_code=httpx.codes.OK,
|
||||||
content=json_content.encode("utf-8"),
|
content=json_content.encode("utf-8"),
|
||||||
|
@ -335,7 +336,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
|
||||||
url=options.url,
|
url=options.url,
|
||||||
params=options.params,
|
params=options.params,
|
||||||
headers=options.headers or {},
|
headers=options.headers or {},
|
||||||
json=options.json_data,
|
json=convert_pydantic_to_json_value(body),
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
response = APIResponse(
|
response = APIResponse(
|
||||||
|
@ -384,7 +385,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
|
||||||
url=options.url,
|
url=options.url,
|
||||||
params=options.params,
|
params=options.params,
|
||||||
headers=options.headers or {},
|
headers=options.headers or {},
|
||||||
json=options.json_data,
|
json=convert_pydantic_to_json_value(body),
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -5,9 +5,9 @@
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
import importlib
|
import importlib
|
||||||
import inspect
|
import inspect
|
||||||
import logging
|
|
||||||
from typing import Any, Dict, List, Set
|
from typing import Any, Dict, List, Set
|
||||||
|
|
||||||
|
from llama_stack import logcat
|
||||||
from llama_stack.apis.agents import Agents
|
from llama_stack.apis.agents import Agents
|
||||||
from llama_stack.apis.benchmarks import Benchmarks
|
from llama_stack.apis.benchmarks import Benchmarks
|
||||||
from llama_stack.apis.datasetio import DatasetIO
|
from llama_stack.apis.datasetio import DatasetIO
|
||||||
|
@ -50,8 +50,6 @@ from llama_stack.providers.datatypes import (
|
||||||
VectorDBsProtocolPrivate,
|
VectorDBsProtocolPrivate,
|
||||||
)
|
)
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class InvalidProviderError(Exception):
|
class InvalidProviderError(Exception):
|
||||||
pass
|
pass
|
||||||
|
@ -115,8 +113,8 @@ async def resolve_impls(
|
||||||
- flatmaps, sorts and resolves the providers in dependency order
|
- flatmaps, sorts and resolves the providers in dependency order
|
||||||
- for each API, produces either a (local, passthrough or router) implementation
|
- for each API, produces either a (local, passthrough or router) implementation
|
||||||
"""
|
"""
|
||||||
routing_table_apis = set(x.routing_table_api for x in builtin_automatically_routed_apis())
|
routing_table_apis = {x.routing_table_api for x in builtin_automatically_routed_apis()}
|
||||||
router_apis = set(x.router_api for x in builtin_automatically_routed_apis())
|
router_apis = {x.router_api for x in builtin_automatically_routed_apis()}
|
||||||
|
|
||||||
providers_with_specs = {}
|
providers_with_specs = {}
|
||||||
|
|
||||||
|
@ -127,16 +125,21 @@ async def resolve_impls(
|
||||||
|
|
||||||
specs = {}
|
specs = {}
|
||||||
for provider in providers:
|
for provider in providers:
|
||||||
|
if not provider.provider_id or provider.provider_id == "__disabled__":
|
||||||
|
logcat.warning("core", f"Provider `{provider.provider_type}` for API `{api}` is disabled")
|
||||||
|
continue
|
||||||
|
|
||||||
if provider.provider_type not in provider_registry[api]:
|
if provider.provider_type not in provider_registry[api]:
|
||||||
raise ValueError(f"Provider `{provider.provider_type}` is not available for API `{api}`")
|
raise ValueError(f"Provider `{provider.provider_type}` is not available for API `{api}`")
|
||||||
|
|
||||||
p = provider_registry[api][provider.provider_type]
|
p = provider_registry[api][provider.provider_type]
|
||||||
if p.deprecation_error:
|
if p.deprecation_error:
|
||||||
log.error(p.deprecation_error, "red", attrs=["bold"])
|
logcat.error("core", p.deprecation_error)
|
||||||
raise InvalidProviderError(p.deprecation_error)
|
raise InvalidProviderError(p.deprecation_error)
|
||||||
|
|
||||||
elif p.deprecation_warning:
|
elif p.deprecation_warning:
|
||||||
log.warning(
|
logcat.warning(
|
||||||
|
"core",
|
||||||
f"Provider `{provider.provider_type}` for API `{api}` is deprecated and will be removed in a future release: {p.deprecation_warning}",
|
f"Provider `{provider.provider_type}` for API `{api}` is deprecated and will be removed in a future release: {p.deprecation_warning}",
|
||||||
)
|
)
|
||||||
p.deps__ = [a.value for a in p.api_dependencies] + [a.value for a in p.optional_api_dependencies]
|
p.deps__ = [a.value for a in p.api_dependencies] + [a.value for a in p.optional_api_dependencies]
|
||||||
|
@ -210,10 +213,10 @@ async def resolve_impls(
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
log.info(f"Resolved {len(sorted_providers)} providers")
|
logcat.debug("core", f"Resolved {len(sorted_providers)} providers")
|
||||||
for api_str, provider in sorted_providers:
|
for api_str, provider in sorted_providers:
|
||||||
log.info(f" {api_str} => {provider.provider_id}")
|
logcat.debug("core", f" {api_str} => {provider.provider_id}")
|
||||||
log.info("")
|
logcat.debug("core", "")
|
||||||
|
|
||||||
impls = {}
|
impls = {}
|
||||||
inner_impls_by_provider_id = {f"inner-{x.value}": {} for x in router_apis}
|
inner_impls_by_provider_id = {f"inner-{x.value}": {} for x in router_apis}
|
||||||
|
@ -350,7 +353,7 @@ def check_protocol_compliance(obj: Any, protocol: Any) -> None:
|
||||||
obj_params = set(obj_sig.parameters)
|
obj_params = set(obj_sig.parameters)
|
||||||
obj_params.discard("self")
|
obj_params.discard("self")
|
||||||
if not (proto_params <= obj_params):
|
if not (proto_params <= obj_params):
|
||||||
log.error(f"Method {name} incompatible proto: {proto_params} vs. obj: {obj_params}")
|
logcat.error("core", f"Method {name} incompatible proto: {proto_params} vs. obj: {obj_params}")
|
||||||
missing_methods.append((name, "signature_mismatch"))
|
missing_methods.append((name, "signature_mismatch"))
|
||||||
else:
|
else:
|
||||||
# Check if the method is actually implemented in the class
|
# Check if the method is actually implemented in the class
|
||||||
|
|
|
@ -4,8 +4,10 @@
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
import copy
|
||||||
from typing import Any, AsyncGenerator, Dict, List, Optional
|
from typing import Any, AsyncGenerator, Dict, List, Optional
|
||||||
|
|
||||||
|
from llama_stack import logcat
|
||||||
from llama_stack.apis.common.content_types import (
|
from llama_stack.apis.common.content_types import (
|
||||||
URL,
|
URL,
|
||||||
InterleavedContent,
|
InterleavedContent,
|
||||||
|
@ -62,12 +64,15 @@ class VectorIORouter(VectorIO):
|
||||||
self,
|
self,
|
||||||
routing_table: RoutingTable,
|
routing_table: RoutingTable,
|
||||||
) -> None:
|
) -> None:
|
||||||
|
logcat.debug("core", "Initializing VectorIORouter")
|
||||||
self.routing_table = routing_table
|
self.routing_table = routing_table
|
||||||
|
|
||||||
async def initialize(self) -> None:
|
async def initialize(self) -> None:
|
||||||
|
logcat.debug("core", "VectorIORouter.initialize")
|
||||||
pass
|
pass
|
||||||
|
|
||||||
async def shutdown(self) -> None:
|
async def shutdown(self) -> None:
|
||||||
|
logcat.debug("core", "VectorIORouter.shutdown")
|
||||||
pass
|
pass
|
||||||
|
|
||||||
async def register_vector_db(
|
async def register_vector_db(
|
||||||
|
@ -78,6 +83,7 @@ class VectorIORouter(VectorIO):
|
||||||
provider_id: Optional[str] = None,
|
provider_id: Optional[str] = None,
|
||||||
provider_vector_db_id: Optional[str] = None,
|
provider_vector_db_id: Optional[str] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
|
logcat.debug("core", f"VectorIORouter.register_vector_db: {vector_db_id}, {embedding_model}")
|
||||||
await self.routing_table.register_vector_db(
|
await self.routing_table.register_vector_db(
|
||||||
vector_db_id,
|
vector_db_id,
|
||||||
embedding_model,
|
embedding_model,
|
||||||
|
@ -92,6 +98,10 @@ class VectorIORouter(VectorIO):
|
||||||
chunks: List[Chunk],
|
chunks: List[Chunk],
|
||||||
ttl_seconds: Optional[int] = None,
|
ttl_seconds: Optional[int] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
|
logcat.debug(
|
||||||
|
"core",
|
||||||
|
f"VectorIORouter.insert_chunks: {vector_db_id}, {len(chunks)} chunks, ttl_seconds={ttl_seconds}, chunk_ids={[chunk.id for chunk in chunks[:3]]}{' and more...' if len(chunks) > 3 else ''}",
|
||||||
|
)
|
||||||
return await self.routing_table.get_provider_impl(vector_db_id).insert_chunks(vector_db_id, chunks, ttl_seconds)
|
return await self.routing_table.get_provider_impl(vector_db_id).insert_chunks(vector_db_id, chunks, ttl_seconds)
|
||||||
|
|
||||||
async def query_chunks(
|
async def query_chunks(
|
||||||
|
@ -100,6 +110,7 @@ class VectorIORouter(VectorIO):
|
||||||
query: InterleavedContent,
|
query: InterleavedContent,
|
||||||
params: Optional[Dict[str, Any]] = None,
|
params: Optional[Dict[str, Any]] = None,
|
||||||
) -> QueryChunksResponse:
|
) -> QueryChunksResponse:
|
||||||
|
logcat.debug("core", f"VectorIORouter.query_chunks: {vector_db_id}")
|
||||||
return await self.routing_table.get_provider_impl(vector_db_id).query_chunks(vector_db_id, query, params)
|
return await self.routing_table.get_provider_impl(vector_db_id).query_chunks(vector_db_id, query, params)
|
||||||
|
|
||||||
|
|
||||||
|
@ -110,12 +121,15 @@ class InferenceRouter(Inference):
|
||||||
self,
|
self,
|
||||||
routing_table: RoutingTable,
|
routing_table: RoutingTable,
|
||||||
) -> None:
|
) -> None:
|
||||||
|
logcat.debug("core", "Initializing InferenceRouter")
|
||||||
self.routing_table = routing_table
|
self.routing_table = routing_table
|
||||||
|
|
||||||
async def initialize(self) -> None:
|
async def initialize(self) -> None:
|
||||||
|
logcat.debug("core", "InferenceRouter.initialize")
|
||||||
pass
|
pass
|
||||||
|
|
||||||
async def shutdown(self) -> None:
|
async def shutdown(self) -> None:
|
||||||
|
logcat.debug("core", "InferenceRouter.shutdown")
|
||||||
pass
|
pass
|
||||||
|
|
||||||
async def register_model(
|
async def register_model(
|
||||||
|
@ -126,6 +140,10 @@ class InferenceRouter(Inference):
|
||||||
metadata: Optional[Dict[str, Any]] = None,
|
metadata: Optional[Dict[str, Any]] = None,
|
||||||
model_type: Optional[ModelType] = None,
|
model_type: Optional[ModelType] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
|
logcat.debug(
|
||||||
|
"core",
|
||||||
|
f"InferenceRouter.register_model: {model_id=} {provider_model_id=} {provider_id=} {metadata=} {model_type=}",
|
||||||
|
)
|
||||||
await self.routing_table.register_model(model_id, provider_model_id, provider_id, metadata, model_type)
|
await self.routing_table.register_model(model_id, provider_model_id, provider_id, metadata, model_type)
|
||||||
|
|
||||||
async def chat_completion(
|
async def chat_completion(
|
||||||
|
@ -141,6 +159,10 @@ class InferenceRouter(Inference):
|
||||||
logprobs: Optional[LogProbConfig] = None,
|
logprobs: Optional[LogProbConfig] = None,
|
||||||
tool_config: Optional[ToolConfig] = None,
|
tool_config: Optional[ToolConfig] = None,
|
||||||
) -> AsyncGenerator:
|
) -> AsyncGenerator:
|
||||||
|
logcat.debug(
|
||||||
|
"core",
|
||||||
|
f"InferenceRouter.chat_completion: {model_id=}, {stream=}, {messages=}, {tools=}, {tool_config=}, {response_format=}",
|
||||||
|
)
|
||||||
model = await self.routing_table.get_model(model_id)
|
model = await self.routing_table.get_model(model_id)
|
||||||
if model is None:
|
if model is None:
|
||||||
raise ValueError(f"Model '{model_id}' not found")
|
raise ValueError(f"Model '{model_id}' not found")
|
||||||
|
@ -159,6 +181,7 @@ class InferenceRouter(Inference):
|
||||||
params["tool_prompt_format"] = tool_prompt_format
|
params["tool_prompt_format"] = tool_prompt_format
|
||||||
tool_config = ToolConfig(**params)
|
tool_config = ToolConfig(**params)
|
||||||
|
|
||||||
|
tool_config = copy.copy(tool_config)
|
||||||
tool_config.tool_prompt_format = tool_config.tool_prompt_format or get_default_tool_prompt_format(model_id)
|
tool_config.tool_prompt_format = tool_config.tool_prompt_format or get_default_tool_prompt_format(model_id)
|
||||||
|
|
||||||
tools = tools or []
|
tools = tools or []
|
||||||
|
@ -201,6 +224,10 @@ class InferenceRouter(Inference):
|
||||||
stream: Optional[bool] = False,
|
stream: Optional[bool] = False,
|
||||||
logprobs: Optional[LogProbConfig] = None,
|
logprobs: Optional[LogProbConfig] = None,
|
||||||
) -> AsyncGenerator:
|
) -> AsyncGenerator:
|
||||||
|
logcat.debug(
|
||||||
|
"core",
|
||||||
|
f"InferenceRouter.completion: {model_id=}, {stream=}, {content=}, {sampling_params=}, {response_format=}",
|
||||||
|
)
|
||||||
model = await self.routing_table.get_model(model_id)
|
model = await self.routing_table.get_model(model_id)
|
||||||
if model is None:
|
if model is None:
|
||||||
raise ValueError(f"Model '{model_id}' not found")
|
raise ValueError(f"Model '{model_id}' not found")
|
||||||
|
@ -228,6 +255,7 @@ class InferenceRouter(Inference):
|
||||||
output_dimension: Optional[int] = None,
|
output_dimension: Optional[int] = None,
|
||||||
task_type: Optional[EmbeddingTaskType] = None,
|
task_type: Optional[EmbeddingTaskType] = None,
|
||||||
) -> EmbeddingsResponse:
|
) -> EmbeddingsResponse:
|
||||||
|
logcat.debug("core", f"InferenceRouter.embeddings: {model_id}")
|
||||||
model = await self.routing_table.get_model(model_id)
|
model = await self.routing_table.get_model(model_id)
|
||||||
if model is None:
|
if model is None:
|
||||||
raise ValueError(f"Model '{model_id}' not found")
|
raise ValueError(f"Model '{model_id}' not found")
|
||||||
|
@ -247,12 +275,15 @@ class SafetyRouter(Safety):
|
||||||
self,
|
self,
|
||||||
routing_table: RoutingTable,
|
routing_table: RoutingTable,
|
||||||
) -> None:
|
) -> None:
|
||||||
|
logcat.debug("core", "Initializing SafetyRouter")
|
||||||
self.routing_table = routing_table
|
self.routing_table = routing_table
|
||||||
|
|
||||||
async def initialize(self) -> None:
|
async def initialize(self) -> None:
|
||||||
|
logcat.debug("core", "SafetyRouter.initialize")
|
||||||
pass
|
pass
|
||||||
|
|
||||||
async def shutdown(self) -> None:
|
async def shutdown(self) -> None:
|
||||||
|
logcat.debug("core", "SafetyRouter.shutdown")
|
||||||
pass
|
pass
|
||||||
|
|
||||||
async def register_shield(
|
async def register_shield(
|
||||||
|
@ -262,6 +293,7 @@ class SafetyRouter(Safety):
|
||||||
provider_id: Optional[str] = None,
|
provider_id: Optional[str] = None,
|
||||||
params: Optional[Dict[str, Any]] = None,
|
params: Optional[Dict[str, Any]] = None,
|
||||||
) -> Shield:
|
) -> Shield:
|
||||||
|
logcat.debug("core", f"SafetyRouter.register_shield: {shield_id}")
|
||||||
return await self.routing_table.register_shield(shield_id, provider_shield_id, provider_id, params)
|
return await self.routing_table.register_shield(shield_id, provider_shield_id, provider_id, params)
|
||||||
|
|
||||||
async def run_shield(
|
async def run_shield(
|
||||||
|
@ -270,6 +302,7 @@ class SafetyRouter(Safety):
|
||||||
messages: List[Message],
|
messages: List[Message],
|
||||||
params: Dict[str, Any] = None,
|
params: Dict[str, Any] = None,
|
||||||
) -> RunShieldResponse:
|
) -> RunShieldResponse:
|
||||||
|
logcat.debug("core", f"SafetyRouter.run_shield: {shield_id}")
|
||||||
return await self.routing_table.get_provider_impl(shield_id).run_shield(
|
return await self.routing_table.get_provider_impl(shield_id).run_shield(
|
||||||
shield_id=shield_id,
|
shield_id=shield_id,
|
||||||
messages=messages,
|
messages=messages,
|
||||||
|
@ -282,12 +315,15 @@ class DatasetIORouter(DatasetIO):
|
||||||
self,
|
self,
|
||||||
routing_table: RoutingTable,
|
routing_table: RoutingTable,
|
||||||
) -> None:
|
) -> None:
|
||||||
|
logcat.debug("core", "Initializing DatasetIORouter")
|
||||||
self.routing_table = routing_table
|
self.routing_table = routing_table
|
||||||
|
|
||||||
async def initialize(self) -> None:
|
async def initialize(self) -> None:
|
||||||
|
logcat.debug("core", "DatasetIORouter.initialize")
|
||||||
pass
|
pass
|
||||||
|
|
||||||
async def shutdown(self) -> None:
|
async def shutdown(self) -> None:
|
||||||
|
logcat.debug("core", "DatasetIORouter.shutdown")
|
||||||
pass
|
pass
|
||||||
|
|
||||||
async def get_rows_paginated(
|
async def get_rows_paginated(
|
||||||
|
@ -297,6 +333,7 @@ class DatasetIORouter(DatasetIO):
|
||||||
page_token: Optional[str] = None,
|
page_token: Optional[str] = None,
|
||||||
filter_condition: Optional[str] = None,
|
filter_condition: Optional[str] = None,
|
||||||
) -> PaginatedRowsResult:
|
) -> PaginatedRowsResult:
|
||||||
|
logcat.debug("core", f"DatasetIORouter.get_rows_paginated: {dataset_id}, rows_in_page={rows_in_page}")
|
||||||
return await self.routing_table.get_provider_impl(dataset_id).get_rows_paginated(
|
return await self.routing_table.get_provider_impl(dataset_id).get_rows_paginated(
|
||||||
dataset_id=dataset_id,
|
dataset_id=dataset_id,
|
||||||
rows_in_page=rows_in_page,
|
rows_in_page=rows_in_page,
|
||||||
|
@ -305,6 +342,7 @@ class DatasetIORouter(DatasetIO):
|
||||||
)
|
)
|
||||||
|
|
||||||
async def append_rows(self, dataset_id: str, rows: List[Dict[str, Any]]) -> None:
|
async def append_rows(self, dataset_id: str, rows: List[Dict[str, Any]]) -> None:
|
||||||
|
logcat.debug("core", f"DatasetIORouter.append_rows: {dataset_id}, {len(rows)} rows")
|
||||||
return await self.routing_table.get_provider_impl(dataset_id).append_rows(
|
return await self.routing_table.get_provider_impl(dataset_id).append_rows(
|
||||||
dataset_id=dataset_id,
|
dataset_id=dataset_id,
|
||||||
rows=rows,
|
rows=rows,
|
||||||
|
@ -316,12 +354,15 @@ class ScoringRouter(Scoring):
|
||||||
self,
|
self,
|
||||||
routing_table: RoutingTable,
|
routing_table: RoutingTable,
|
||||||
) -> None:
|
) -> None:
|
||||||
|
logcat.debug("core", "Initializing ScoringRouter")
|
||||||
self.routing_table = routing_table
|
self.routing_table = routing_table
|
||||||
|
|
||||||
async def initialize(self) -> None:
|
async def initialize(self) -> None:
|
||||||
|
logcat.debug("core", "ScoringRouter.initialize")
|
||||||
pass
|
pass
|
||||||
|
|
||||||
async def shutdown(self) -> None:
|
async def shutdown(self) -> None:
|
||||||
|
logcat.debug("core", "ScoringRouter.shutdown")
|
||||||
pass
|
pass
|
||||||
|
|
||||||
async def score_batch(
|
async def score_batch(
|
||||||
|
@ -330,6 +371,7 @@ class ScoringRouter(Scoring):
|
||||||
scoring_functions: Dict[str, Optional[ScoringFnParams]] = None,
|
scoring_functions: Dict[str, Optional[ScoringFnParams]] = None,
|
||||||
save_results_dataset: bool = False,
|
save_results_dataset: bool = False,
|
||||||
) -> ScoreBatchResponse:
|
) -> ScoreBatchResponse:
|
||||||
|
logcat.debug("core", f"ScoringRouter.score_batch: {dataset_id}")
|
||||||
res = {}
|
res = {}
|
||||||
for fn_identifier in scoring_functions.keys():
|
for fn_identifier in scoring_functions.keys():
|
||||||
score_response = await self.routing_table.get_provider_impl(fn_identifier).score_batch(
|
score_response = await self.routing_table.get_provider_impl(fn_identifier).score_batch(
|
||||||
|
@ -350,6 +392,7 @@ class ScoringRouter(Scoring):
|
||||||
input_rows: List[Dict[str, Any]],
|
input_rows: List[Dict[str, Any]],
|
||||||
scoring_functions: Dict[str, Optional[ScoringFnParams]] = None,
|
scoring_functions: Dict[str, Optional[ScoringFnParams]] = None,
|
||||||
) -> ScoreResponse:
|
) -> ScoreResponse:
|
||||||
|
logcat.debug("core", f"ScoringRouter.score: {len(input_rows)} rows, {len(scoring_functions)} functions")
|
||||||
res = {}
|
res = {}
|
||||||
# look up and map each scoring function to its provider impl
|
# look up and map each scoring function to its provider impl
|
||||||
for fn_identifier in scoring_functions.keys():
|
for fn_identifier in scoring_functions.keys():
|
||||||
|
@ -367,12 +410,15 @@ class EvalRouter(Eval):
|
||||||
self,
|
self,
|
||||||
routing_table: RoutingTable,
|
routing_table: RoutingTable,
|
||||||
) -> None:
|
) -> None:
|
||||||
|
logcat.debug("core", "Initializing EvalRouter")
|
||||||
self.routing_table = routing_table
|
self.routing_table = routing_table
|
||||||
|
|
||||||
async def initialize(self) -> None:
|
async def initialize(self) -> None:
|
||||||
|
logcat.debug("core", "EvalRouter.initialize")
|
||||||
pass
|
pass
|
||||||
|
|
||||||
async def shutdown(self) -> None:
|
async def shutdown(self) -> None:
|
||||||
|
logcat.debug("core", "EvalRouter.shutdown")
|
||||||
pass
|
pass
|
||||||
|
|
||||||
async def run_eval(
|
async def run_eval(
|
||||||
|
@ -380,6 +426,7 @@ class EvalRouter(Eval):
|
||||||
benchmark_id: str,
|
benchmark_id: str,
|
||||||
task_config: BenchmarkConfig,
|
task_config: BenchmarkConfig,
|
||||||
) -> Job:
|
) -> Job:
|
||||||
|
logcat.debug("core", f"EvalRouter.run_eval: {benchmark_id}")
|
||||||
return await self.routing_table.get_provider_impl(benchmark_id).run_eval(
|
return await self.routing_table.get_provider_impl(benchmark_id).run_eval(
|
||||||
benchmark_id=benchmark_id,
|
benchmark_id=benchmark_id,
|
||||||
task_config=task_config,
|
task_config=task_config,
|
||||||
|
@ -392,6 +439,7 @@ class EvalRouter(Eval):
|
||||||
scoring_functions: List[str],
|
scoring_functions: List[str],
|
||||||
task_config: BenchmarkConfig,
|
task_config: BenchmarkConfig,
|
||||||
) -> EvaluateResponse:
|
) -> EvaluateResponse:
|
||||||
|
logcat.debug("core", f"EvalRouter.evaluate_rows: {benchmark_id}, {len(input_rows)} rows")
|
||||||
return await self.routing_table.get_provider_impl(benchmark_id).evaluate_rows(
|
return await self.routing_table.get_provider_impl(benchmark_id).evaluate_rows(
|
||||||
benchmark_id=benchmark_id,
|
benchmark_id=benchmark_id,
|
||||||
input_rows=input_rows,
|
input_rows=input_rows,
|
||||||
|
@ -404,6 +452,7 @@ class EvalRouter(Eval):
|
||||||
benchmark_id: str,
|
benchmark_id: str,
|
||||||
job_id: str,
|
job_id: str,
|
||||||
) -> Optional[JobStatus]:
|
) -> Optional[JobStatus]:
|
||||||
|
logcat.debug("core", f"EvalRouter.job_status: {benchmark_id}, {job_id}")
|
||||||
return await self.routing_table.get_provider_impl(benchmark_id).job_status(benchmark_id, job_id)
|
return await self.routing_table.get_provider_impl(benchmark_id).job_status(benchmark_id, job_id)
|
||||||
|
|
||||||
async def job_cancel(
|
async def job_cancel(
|
||||||
|
@ -411,6 +460,7 @@ class EvalRouter(Eval):
|
||||||
benchmark_id: str,
|
benchmark_id: str,
|
||||||
job_id: str,
|
job_id: str,
|
||||||
) -> None:
|
) -> None:
|
||||||
|
logcat.debug("core", f"EvalRouter.job_cancel: {benchmark_id}, {job_id}")
|
||||||
await self.routing_table.get_provider_impl(benchmark_id).job_cancel(
|
await self.routing_table.get_provider_impl(benchmark_id).job_cancel(
|
||||||
benchmark_id,
|
benchmark_id,
|
||||||
job_id,
|
job_id,
|
||||||
|
@ -421,6 +471,7 @@ class EvalRouter(Eval):
|
||||||
benchmark_id: str,
|
benchmark_id: str,
|
||||||
job_id: str,
|
job_id: str,
|
||||||
) -> EvaluateResponse:
|
) -> EvaluateResponse:
|
||||||
|
logcat.debug("core", f"EvalRouter.job_result: {benchmark_id}, {job_id}")
|
||||||
return await self.routing_table.get_provider_impl(benchmark_id).job_result(
|
return await self.routing_table.get_provider_impl(benchmark_id).job_result(
|
||||||
benchmark_id,
|
benchmark_id,
|
||||||
job_id,
|
job_id,
|
||||||
|
@ -433,6 +484,7 @@ class ToolRuntimeRouter(ToolRuntime):
|
||||||
self,
|
self,
|
||||||
routing_table: RoutingTable,
|
routing_table: RoutingTable,
|
||||||
) -> None:
|
) -> None:
|
||||||
|
logcat.debug("core", "Initializing ToolRuntimeRouter.RagToolImpl")
|
||||||
self.routing_table = routing_table
|
self.routing_table = routing_table
|
||||||
|
|
||||||
async def query(
|
async def query(
|
||||||
|
@ -441,6 +493,7 @@ class ToolRuntimeRouter(ToolRuntime):
|
||||||
vector_db_ids: List[str],
|
vector_db_ids: List[str],
|
||||||
query_config: Optional[RAGQueryConfig] = None,
|
query_config: Optional[RAGQueryConfig] = None,
|
||||||
) -> RAGQueryResult:
|
) -> RAGQueryResult:
|
||||||
|
logcat.debug("core", f"ToolRuntimeRouter.RagToolImpl.query: {vector_db_ids}")
|
||||||
return await self.routing_table.get_provider_impl("knowledge_search").query(
|
return await self.routing_table.get_provider_impl("knowledge_search").query(
|
||||||
content, vector_db_ids, query_config
|
content, vector_db_ids, query_config
|
||||||
)
|
)
|
||||||
|
@ -451,6 +504,10 @@ class ToolRuntimeRouter(ToolRuntime):
|
||||||
vector_db_id: str,
|
vector_db_id: str,
|
||||||
chunk_size_in_tokens: int = 512,
|
chunk_size_in_tokens: int = 512,
|
||||||
) -> None:
|
) -> None:
|
||||||
|
logcat.debug(
|
||||||
|
"core",
|
||||||
|
f"ToolRuntimeRouter.RagToolImpl.insert: {vector_db_id}, {len(documents)} documents, chunk_size={chunk_size_in_tokens}",
|
||||||
|
)
|
||||||
return await self.routing_table.get_provider_impl("insert_into_memory").insert(
|
return await self.routing_table.get_provider_impl("insert_into_memory").insert(
|
||||||
documents, vector_db_id, chunk_size_in_tokens
|
documents, vector_db_id, chunk_size_in_tokens
|
||||||
)
|
)
|
||||||
|
@ -459,6 +516,7 @@ class ToolRuntimeRouter(ToolRuntime):
|
||||||
self,
|
self,
|
||||||
routing_table: RoutingTable,
|
routing_table: RoutingTable,
|
||||||
) -> None:
|
) -> None:
|
||||||
|
logcat.debug("core", "Initializing ToolRuntimeRouter")
|
||||||
self.routing_table = routing_table
|
self.routing_table = routing_table
|
||||||
|
|
||||||
# HACK ALERT this should be in sync with "get_all_api_endpoints()"
|
# HACK ALERT this should be in sync with "get_all_api_endpoints()"
|
||||||
|
@ -467,12 +525,15 @@ class ToolRuntimeRouter(ToolRuntime):
|
||||||
setattr(self, f"rag_tool.{method}", getattr(self.rag_tool, method))
|
setattr(self, f"rag_tool.{method}", getattr(self.rag_tool, method))
|
||||||
|
|
||||||
async def initialize(self) -> None:
|
async def initialize(self) -> None:
|
||||||
|
logcat.debug("core", "ToolRuntimeRouter.initialize")
|
||||||
pass
|
pass
|
||||||
|
|
||||||
async def shutdown(self) -> None:
|
async def shutdown(self) -> None:
|
||||||
|
logcat.debug("core", "ToolRuntimeRouter.shutdown")
|
||||||
pass
|
pass
|
||||||
|
|
||||||
async def invoke_tool(self, tool_name: str, kwargs: Dict[str, Any]) -> Any:
|
async def invoke_tool(self, tool_name: str, kwargs: Dict[str, Any]) -> Any:
|
||||||
|
logcat.debug("core", f"ToolRuntimeRouter.invoke_tool: {tool_name}")
|
||||||
return await self.routing_table.get_provider_impl(tool_name).invoke_tool(
|
return await self.routing_table.get_provider_impl(tool_name).invoke_tool(
|
||||||
tool_name=tool_name,
|
tool_name=tool_name,
|
||||||
kwargs=kwargs,
|
kwargs=kwargs,
|
||||||
|
@ -481,4 +542,5 @@ class ToolRuntimeRouter(ToolRuntime):
|
||||||
async def list_runtime_tools(
|
async def list_runtime_tools(
|
||||||
self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
|
self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
|
||||||
) -> List[ToolDef]:
|
) -> List[ToolDef]:
|
||||||
|
logcat.debug("core", f"ToolRuntimeRouter.list_runtime_tools: {tool_group_id}")
|
||||||
return await self.routing_table.get_provider_impl(tool_group_id).list_tools(tool_group_id, mcp_endpoint)
|
return await self.routing_table.get_provider_impl(tool_group_id).list_tools(tool_group_id, mcp_endpoint)
|
||||||
|
|
|
@ -318,14 +318,7 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
|
||||||
)
|
)
|
||||||
model = await self.get_object_by_identifier("model", embedding_model)
|
model = await self.get_object_by_identifier("model", embedding_model)
|
||||||
if model is None:
|
if model is None:
|
||||||
if embedding_model == "all-MiniLM-L6-v2":
|
raise ValueError(f"Model {embedding_model} not found")
|
||||||
raise ValueError(
|
|
||||||
"Embeddings are now served via Inference providers. "
|
|
||||||
"Please upgrade your run.yaml to include inline::sentence-transformer as an additional inference provider. "
|
|
||||||
"See https://github.com/meta-llama/llama-stack/blob/main/llama_stack/templates/together/run.yaml for an example."
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Model {embedding_model} not found")
|
|
||||||
if model.model_type != ModelType.embedding:
|
if model.model_type != ModelType.embedding:
|
||||||
raise ValueError(f"Model {embedding_model} is not an embedding model")
|
raise ValueError(f"Model {embedding_model} is not an embedding model")
|
||||||
if "embedding_dimension" not in model.metadata:
|
if "embedding_dimension" not in model.metadata:
|
||||||
|
|
|
@ -26,9 +26,9 @@ from fastapi import Path as FastapiPath
|
||||||
from fastapi.exceptions import RequestValidationError
|
from fastapi.exceptions import RequestValidationError
|
||||||
from fastapi.responses import JSONResponse, StreamingResponse
|
from fastapi.responses import JSONResponse, StreamingResponse
|
||||||
from pydantic import BaseModel, ValidationError
|
from pydantic import BaseModel, ValidationError
|
||||||
from termcolor import cprint
|
|
||||||
from typing_extensions import Annotated
|
from typing_extensions import Annotated
|
||||||
|
|
||||||
|
from llama_stack import logcat
|
||||||
from llama_stack.distribution.datatypes import StackRunConfig
|
from llama_stack.distribution.datatypes import StackRunConfig
|
||||||
from llama_stack.distribution.distribution import builtin_automatically_routed_apis
|
from llama_stack.distribution.distribution import builtin_automatically_routed_apis
|
||||||
from llama_stack.distribution.request_headers import set_request_provider_data
|
from llama_stack.distribution.request_headers import set_request_provider_data
|
||||||
|
@ -55,7 +55,7 @@ from .endpoints import get_all_api_endpoints
|
||||||
REPO_ROOT = Path(__file__).parent.parent.parent.parent
|
REPO_ROOT = Path(__file__).parent.parent.parent.parent
|
||||||
|
|
||||||
logging.basicConfig(level=logging.INFO, format="%(levelname)s %(asctime)s %(name)s:%(lineno)d: %(message)s")
|
logging.basicConfig(level=logging.INFO, format="%(levelname)s %(asctime)s %(name)s:%(lineno)d: %(message)s")
|
||||||
logger = logging.getLogger(__name__)
|
logcat.init()
|
||||||
|
|
||||||
|
|
||||||
def warn_with_traceback(message, category, filename, lineno, file=None, line=None):
|
def warn_with_traceback(message, category, filename, lineno, file=None, line=None):
|
||||||
|
@ -142,23 +142,23 @@ def handle_signal(app, signum, _) -> None:
|
||||||
not block the current execution.
|
not block the current execution.
|
||||||
"""
|
"""
|
||||||
signame = signal.Signals(signum).name
|
signame = signal.Signals(signum).name
|
||||||
logger.info(f"Received signal {signame} ({signum}). Exiting gracefully...")
|
logcat.info("server", f"Received signal {signame} ({signum}). Exiting gracefully...")
|
||||||
|
|
||||||
async def shutdown():
|
async def shutdown():
|
||||||
try:
|
try:
|
||||||
# Gracefully shut down implementations
|
# Gracefully shut down implementations
|
||||||
for impl in app.__llama_stack_impls__.values():
|
for impl in app.__llama_stack_impls__.values():
|
||||||
impl_name = impl.__class__.__name__
|
impl_name = impl.__class__.__name__
|
||||||
logger.info("Shutting down %s", impl_name)
|
logcat.info("server", f"Shutting down {impl_name}")
|
||||||
try:
|
try:
|
||||||
if hasattr(impl, "shutdown"):
|
if hasattr(impl, "shutdown"):
|
||||||
await asyncio.wait_for(impl.shutdown(), timeout=5)
|
await asyncio.wait_for(impl.shutdown(), timeout=5)
|
||||||
else:
|
else:
|
||||||
logger.warning("No shutdown method for %s", impl_name)
|
logcat.warning("server", f"No shutdown method for {impl_name}")
|
||||||
except asyncio.TimeoutError:
|
except asyncio.TimeoutError:
|
||||||
logger.exception("Shutdown timeout for %s ", impl_name, exc_info=True)
|
logcat.exception("server", f"Shutdown timeout for {impl_name}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.exception("Failed to shutdown %s: %s", impl_name, {e})
|
logcat.exception("server", f"Failed to shutdown {impl_name}: {e}")
|
||||||
|
|
||||||
# Gather all running tasks
|
# Gather all running tasks
|
||||||
loop = asyncio.get_running_loop()
|
loop = asyncio.get_running_loop()
|
||||||
|
@ -172,7 +172,7 @@ def handle_signal(app, signum, _) -> None:
|
||||||
try:
|
try:
|
||||||
await asyncio.wait_for(asyncio.gather(*tasks, return_exceptions=True), timeout=10)
|
await asyncio.wait_for(asyncio.gather(*tasks, return_exceptions=True), timeout=10)
|
||||||
except asyncio.TimeoutError:
|
except asyncio.TimeoutError:
|
||||||
logger.exception("Timeout while waiting for tasks to finish")
|
logcat.exception("server", "Timeout while waiting for tasks to finish")
|
||||||
except asyncio.CancelledError:
|
except asyncio.CancelledError:
|
||||||
pass
|
pass
|
||||||
finally:
|
finally:
|
||||||
|
@ -184,9 +184,9 @@ def handle_signal(app, signum, _) -> None:
|
||||||
|
|
||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
async def lifespan(app: FastAPI):
|
async def lifespan(app: FastAPI):
|
||||||
logger.info("Starting up")
|
logcat.info("server", "Starting up")
|
||||||
yield
|
yield
|
||||||
logger.info("Shutting down")
|
logcat.info("server", "Shutting down")
|
||||||
for impl in app.__llama_stack_impls__.values():
|
for impl in app.__llama_stack_impls__.values():
|
||||||
await impl.shutdown()
|
await impl.shutdown()
|
||||||
|
|
||||||
|
@ -209,10 +209,10 @@ async def sse_generator(event_gen):
|
||||||
yield create_sse_event(item)
|
yield create_sse_event(item)
|
||||||
await asyncio.sleep(0.01)
|
await asyncio.sleep(0.01)
|
||||||
except asyncio.CancelledError:
|
except asyncio.CancelledError:
|
||||||
print("Generator cancelled")
|
logcat.info("server", "Generator cancelled")
|
||||||
await event_gen.aclose()
|
await event_gen.aclose()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
traceback.print_exception(e)
|
logcat.exception("server", "Error in sse_generator")
|
||||||
yield create_sse_event(
|
yield create_sse_event(
|
||||||
{
|
{
|
||||||
"error": {
|
"error": {
|
||||||
|
@ -234,7 +234,7 @@ def create_dynamic_typed_route(func: Any, method: str, route: str):
|
||||||
value = func(**kwargs)
|
value = func(**kwargs)
|
||||||
return await maybe_await(value)
|
return await maybe_await(value)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
traceback.print_exception(e)
|
logcat.exception("server", f"Error in {func.__name__}")
|
||||||
raise translate_exception(e) from e
|
raise translate_exception(e) from e
|
||||||
|
|
||||||
sig = inspect.signature(func)
|
sig = inspect.signature(func)
|
||||||
|
@ -313,6 +313,8 @@ class ClientVersionMiddleware:
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
|
logcat.init()
|
||||||
|
|
||||||
"""Start the LlamaStack server."""
|
"""Start the LlamaStack server."""
|
||||||
parser = argparse.ArgumentParser(description="Start the LlamaStack server.")
|
parser = argparse.ArgumentParser(description="Start the LlamaStack server.")
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
|
@ -352,10 +354,10 @@ def main():
|
||||||
for env_pair in args.env:
|
for env_pair in args.env:
|
||||||
try:
|
try:
|
||||||
key, value = validate_env_pair(env_pair)
|
key, value = validate_env_pair(env_pair)
|
||||||
logger.info(f"Setting CLI environment variable {key} => {value}")
|
logcat.info("server", f"Setting CLI environment variable {key} => {value}")
|
||||||
os.environ[key] = value
|
os.environ[key] = value
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
logger.error(f"Error: {str(e)}")
|
logcat.error("server", f"Error: {str(e)}")
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
|
||||||
if args.yaml_config:
|
if args.yaml_config:
|
||||||
|
@ -363,12 +365,12 @@ def main():
|
||||||
config_file = Path(args.yaml_config)
|
config_file = Path(args.yaml_config)
|
||||||
if not config_file.exists():
|
if not config_file.exists():
|
||||||
raise ValueError(f"Config file {config_file} does not exist")
|
raise ValueError(f"Config file {config_file} does not exist")
|
||||||
logger.info(f"Using config file: {config_file}")
|
logcat.info("server", f"Using config file: {config_file}")
|
||||||
elif args.template:
|
elif args.template:
|
||||||
config_file = Path(REPO_ROOT) / "llama_stack" / "templates" / args.template / "run.yaml"
|
config_file = Path(REPO_ROOT) / "llama_stack" / "templates" / args.template / "run.yaml"
|
||||||
if not config_file.exists():
|
if not config_file.exists():
|
||||||
raise ValueError(f"Template {args.template} does not exist")
|
raise ValueError(f"Template {args.template} does not exist")
|
||||||
logger.info(f"Using template {args.template} config file: {config_file}")
|
logcat.info("server", f"Using template {args.template} config file: {config_file}")
|
||||||
else:
|
else:
|
||||||
raise ValueError("Either --yaml-config or --template must be provided")
|
raise ValueError("Either --yaml-config or --template must be provided")
|
||||||
|
|
||||||
|
@ -376,9 +378,10 @@ def main():
|
||||||
config = replace_env_vars(yaml.safe_load(fp))
|
config = replace_env_vars(yaml.safe_load(fp))
|
||||||
config = StackRunConfig(**config)
|
config = StackRunConfig(**config)
|
||||||
|
|
||||||
logger.info("Run configuration:")
|
logcat.info("server", "Run configuration:")
|
||||||
safe_config = redact_sensitive_fields(config.model_dump())
|
safe_config = redact_sensitive_fields(config.model_dump())
|
||||||
logger.info(yaml.dump(safe_config, indent=2))
|
for log_line in yaml.dump(safe_config, indent=2).split("\n"):
|
||||||
|
logcat.info("server", log_line)
|
||||||
|
|
||||||
app = FastAPI(lifespan=lifespan)
|
app = FastAPI(lifespan=lifespan)
|
||||||
app.add_middleware(TracingMiddleware)
|
app.add_middleware(TracingMiddleware)
|
||||||
|
@ -388,7 +391,7 @@ def main():
|
||||||
try:
|
try:
|
||||||
impls = asyncio.run(construct_stack(config))
|
impls = asyncio.run(construct_stack(config))
|
||||||
except InvalidProviderError as e:
|
except InvalidProviderError as e:
|
||||||
logger.error(f"Error: {str(e)}")
|
logcat.error("server", f"Error: {str(e)}")
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
|
||||||
if Api.telemetry in impls:
|
if Api.telemetry in impls:
|
||||||
|
@ -433,11 +436,8 @@ def main():
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.info(f"Serving API {api_str}")
|
logcat.debug("server", f"Serving API {api_str}")
|
||||||
for endpoint in endpoints:
|
|
||||||
cprint(f" {endpoint.method.upper()} {endpoint.route}", "white")
|
|
||||||
|
|
||||||
print("")
|
|
||||||
app.exception_handler(RequestValidationError)(global_exception_handler)
|
app.exception_handler(RequestValidationError)(global_exception_handler)
|
||||||
app.exception_handler(Exception)(global_exception_handler)
|
app.exception_handler(Exception)(global_exception_handler)
|
||||||
signal.signal(signal.SIGINT, functools.partial(handle_signal, app))
|
signal.signal(signal.SIGINT, functools.partial(handle_signal, app))
|
||||||
|
@ -463,10 +463,10 @@ def main():
|
||||||
"ssl_keyfile": keyfile,
|
"ssl_keyfile": keyfile,
|
||||||
"ssl_certfile": certfile,
|
"ssl_certfile": certfile,
|
||||||
}
|
}
|
||||||
logger.info(f"HTTPS enabled with certificates:\n Key: {keyfile}\n Cert: {certfile}")
|
logcat.info("server", f"HTTPS enabled with certificates:\n Key: {keyfile}\n Cert: {certfile}")
|
||||||
|
|
||||||
listen_host = ["::", "0.0.0.0"] if not args.disable_ipv6 else "0.0.0.0"
|
listen_host = ["::", "0.0.0.0"] if not args.disable_ipv6 else "0.0.0.0"
|
||||||
logger.info(f"Listening on {listen_host}:{port}")
|
logcat.info("server", f"Listening on {listen_host}:{port}")
|
||||||
|
|
||||||
uvicorn_config = {
|
uvicorn_config = {
|
||||||
"app": app,
|
"app": app,
|
||||||
|
|
|
@ -5,7 +5,6 @@
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
import importlib.resources
|
import importlib.resources
|
||||||
import logging
|
|
||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
from typing import Any, Dict, Optional
|
from typing import Any, Dict, Optional
|
||||||
|
@ -13,6 +12,7 @@ from typing import Any, Dict, Optional
|
||||||
import yaml
|
import yaml
|
||||||
from termcolor import colored
|
from termcolor import colored
|
||||||
|
|
||||||
|
from llama_stack import logcat
|
||||||
from llama_stack.apis.agents import Agents
|
from llama_stack.apis.agents import Agents
|
||||||
from llama_stack.apis.batch_inference import BatchInference
|
from llama_stack.apis.batch_inference import BatchInference
|
||||||
from llama_stack.apis.benchmarks import Benchmarks
|
from llama_stack.apis.benchmarks import Benchmarks
|
||||||
|
@ -39,8 +39,6 @@ from llama_stack.distribution.resolver import ProviderRegistry, resolve_impls
|
||||||
from llama_stack.distribution.store.registry import create_dist_registry
|
from llama_stack.distribution.store.registry import create_dist_registry
|
||||||
from llama_stack.providers.datatypes import Api
|
from llama_stack.providers.datatypes import Api
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class LlamaStack(
|
class LlamaStack(
|
||||||
VectorDBs,
|
VectorDBs,
|
||||||
|
@ -101,12 +99,11 @@ async def register_resources(run_config: StackRunConfig, impls: Dict[Api, Any]):
|
||||||
objects_to_process = response.data if hasattr(response, "data") else response
|
objects_to_process = response.data if hasattr(response, "data") else response
|
||||||
|
|
||||||
for obj in objects_to_process:
|
for obj in objects_to_process:
|
||||||
log.info(
|
logcat.debug(
|
||||||
|
"core",
|
||||||
f"{rsrc.capitalize()}: {colored(obj.identifier, 'white', attrs=['bold'])} served by {colored(obj.provider_id, 'white', attrs=['bold'])}",
|
f"{rsrc.capitalize()}: {colored(obj.identifier, 'white', attrs=['bold'])} served by {colored(obj.provider_id, 'white', attrs=['bold'])}",
|
||||||
)
|
)
|
||||||
|
|
||||||
log.info("")
|
|
||||||
|
|
||||||
|
|
||||||
class EnvVarError(Exception):
|
class EnvVarError(Exception):
|
||||||
def __init__(self, var_name: str, path: str = ""):
|
def __init__(self, var_name: str, path: str = ""):
|
||||||
|
@ -155,18 +152,34 @@ def replace_env_vars(config: Any, path: str = "") -> Any:
|
||||||
return result
|
return result
|
||||||
|
|
||||||
elif isinstance(config, str):
|
elif isinstance(config, str):
|
||||||
pattern = r"\${env\.([A-Z0-9_]+)(?::([^}]*))?}"
|
# Updated pattern to support both default values (:) and conditional values (+)
|
||||||
|
pattern = r"\${env\.([A-Z0-9_]+)(?:([:\+])([^}]*))?}"
|
||||||
|
|
||||||
def get_env_var(match):
|
def get_env_var(match):
|
||||||
env_var = match.group(1)
|
env_var = match.group(1)
|
||||||
default_val = match.group(2)
|
operator = match.group(2) # ':' for default, '+' for conditional
|
||||||
|
value_expr = match.group(3)
|
||||||
|
|
||||||
value = os.environ.get(env_var)
|
env_value = os.environ.get(env_var)
|
||||||
if not value:
|
|
||||||
if default_val is None:
|
if operator == ":": # Default value syntax: ${env.FOO:default}
|
||||||
raise EnvVarError(env_var, path)
|
if not env_value:
|
||||||
|
if value_expr is None:
|
||||||
|
raise EnvVarError(env_var, path)
|
||||||
|
else:
|
||||||
|
value = value_expr
|
||||||
else:
|
else:
|
||||||
value = default_val
|
value = env_value
|
||||||
|
elif operator == "+": # Conditional value syntax: ${env.FOO+value_if_set}
|
||||||
|
if env_value:
|
||||||
|
value = value_expr
|
||||||
|
else:
|
||||||
|
# If env var is not set, return empty string for the conditional case
|
||||||
|
value = ""
|
||||||
|
else: # No operator case: ${env.FOO}
|
||||||
|
if not env_value:
|
||||||
|
raise EnvVarError(env_var, path)
|
||||||
|
value = env_value
|
||||||
|
|
||||||
# expand "~" from the values
|
# expand "~" from the values
|
||||||
return os.path.expanduser(value)
|
return os.path.expanduser(value)
|
||||||
|
|
|
@ -98,15 +98,20 @@ case "$env_type" in
|
||||||
*)
|
*)
|
||||||
esac
|
esac
|
||||||
|
|
||||||
set -x
|
|
||||||
|
|
||||||
if [[ "$env_type" == "venv" || "$env_type" == "conda" ]]; then
|
if [[ "$env_type" == "venv" || "$env_type" == "conda" ]]; then
|
||||||
|
set -x
|
||||||
$PYTHON_BINARY -m llama_stack.distribution.server.server \
|
$PYTHON_BINARY -m llama_stack.distribution.server.server \
|
||||||
--yaml-config "$yaml_config" \
|
--yaml-config "$yaml_config" \
|
||||||
--port "$port" \
|
--port "$port" \
|
||||||
$env_vars \
|
$env_vars \
|
||||||
$other_args
|
$other_args
|
||||||
elif [[ "$env_type" == "container" ]]; then
|
elif [[ "$env_type" == "container" ]]; then
|
||||||
|
# Check if container command is available
|
||||||
|
if ! is_command_available $CONTAINER_BINARY; then
|
||||||
|
printf "${RED}Error: ${CONTAINER_BINARY} command not found. Is ${CONTAINER_BINARY} installed and in your PATH?${NC}" >&2
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
if is_command_available selinuxenabled &> /dev/null && selinuxenabled; then
|
if is_command_available selinuxenabled &> /dev/null && selinuxenabled; then
|
||||||
# Disable SELinux labels
|
# Disable SELinux labels
|
||||||
CONTAINER_OPTS="$CONTAINER_OPTS --security-opt label=disable"
|
CONTAINER_OPTS="$CONTAINER_OPTS --security-opt label=disable"
|
||||||
|
@ -136,6 +141,8 @@ elif [[ "$env_type" == "container" ]]; then
|
||||||
version_tag=$(curl -s $URL | jq -r '.info.version')
|
version_tag=$(curl -s $URL | jq -r '.info.version')
|
||||||
fi
|
fi
|
||||||
|
|
||||||
|
set -x
|
||||||
|
|
||||||
$CONTAINER_BINARY run $CONTAINER_OPTS -it \
|
$CONTAINER_BINARY run $CONTAINER_OPTS -it \
|
||||||
-p $port:$port \
|
-p $port:$port \
|
||||||
$env_vars \
|
$env_vars \
|
||||||
|
|
|
@ -1,72 +0,0 @@
|
||||||
#!/bin/bash
|
|
||||||
|
|
||||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
||||||
# All rights reserved.
|
|
||||||
#
|
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
|
||||||
# the root directory of this source tree.
|
|
||||||
|
|
||||||
set -euo pipefail
|
|
||||||
|
|
||||||
RED='\033[0;31m'
|
|
||||||
NC='\033[0m' # No Color
|
|
||||||
|
|
||||||
error_handler() {
|
|
||||||
echo "Error occurred in script at line: ${1}" >&2
|
|
||||||
exit 1
|
|
||||||
}
|
|
||||||
|
|
||||||
trap 'error_handler ${LINENO}' ERR
|
|
||||||
|
|
||||||
if [ $# -lt 3 ]; then
|
|
||||||
echo "Usage: $0 <venv_path> <yaml_config> <port> <script_args...>"
|
|
||||||
exit 1
|
|
||||||
fi
|
|
||||||
|
|
||||||
venv_path="$1"
|
|
||||||
shift
|
|
||||||
|
|
||||||
yaml_config="$1"
|
|
||||||
shift
|
|
||||||
|
|
||||||
port="$1"
|
|
||||||
shift
|
|
||||||
|
|
||||||
# Initialize env_vars as an empty array
|
|
||||||
env_vars=""
|
|
||||||
other_args=""
|
|
||||||
# Process environment variables from --env arguments
|
|
||||||
while [[ $# -gt 0 ]]; do
|
|
||||||
case "$1" in
|
|
||||||
--env)
|
|
||||||
|
|
||||||
if [[ -n "$2" ]]; then
|
|
||||||
env_vars="$env_vars --env $2"
|
|
||||||
shift 2
|
|
||||||
else
|
|
||||||
echo -e "${RED}Error: --env requires a KEY=VALUE argument${NC}" >&2
|
|
||||||
exit 1
|
|
||||||
fi
|
|
||||||
;;
|
|
||||||
*)
|
|
||||||
other_args="$other_args $1"
|
|
||||||
shift
|
|
||||||
;;
|
|
||||||
esac
|
|
||||||
done
|
|
||||||
|
|
||||||
echo "Using virtual environment: $venv_path"
|
|
||||||
# Activate virtual environment
|
|
||||||
if [ ! -d "$venv_path" ]; then
|
|
||||||
echo -e "${RED}Error: Virtual environment not found at $venv_path${NC}" >&2
|
|
||||||
exit 1
|
|
||||||
fi
|
|
||||||
|
|
||||||
source "$venv_path/bin/activate"
|
|
||||||
|
|
||||||
set -x
|
|
||||||
python -m llama_stack.distribution.server.server \
|
|
||||||
--yaml-config "$yaml_config" \
|
|
||||||
--port "$port" \
|
|
||||||
$env_vars \
|
|
||||||
$other_args
|
|
|
@ -134,7 +134,7 @@ def rag_chat_page():
|
||||||
dict(
|
dict(
|
||||||
name="builtin::rag/knowledge_search",
|
name="builtin::rag/knowledge_search",
|
||||||
args={
|
args={
|
||||||
"vector_db_ids": [vector_db_id for vector_db_id in selected_vector_dbs],
|
"vector_db_ids": list(selected_vector_dbs),
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
],
|
],
|
||||||
|
|
|
@ -46,7 +46,7 @@ def formulate_run_args(image_type, image_name, config, template_name) -> list:
|
||||||
conda_env_info = json.loads(subprocess.check_output(["conda", "info", "--envs", "--json"]).decode())
|
conda_env_info = json.loads(subprocess.check_output(["conda", "info", "--envs", "--json"]).decode())
|
||||||
envs = conda_env_info["envs"]
|
envs = conda_env_info["envs"]
|
||||||
for envpath in envs:
|
for envpath in envs:
|
||||||
if envpath.endswith(env_name):
|
if os.path.basename(envpath) == env_name:
|
||||||
return envpath
|
return envpath
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
204
llama_stack/logcat.py
Normal file
204
llama_stack/logcat.py
Normal file
|
@ -0,0 +1,204 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
"""
|
||||||
|
Category-based logging utility for llama-stack.
|
||||||
|
|
||||||
|
This module provides a wrapper over the standard Python logging module that supports
|
||||||
|
categorized logging with environment variable control.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
from llama_stack import logcat
|
||||||
|
logcat.info("server", "Starting up...")
|
||||||
|
logcat.debug("inference", "Processing request...")
|
||||||
|
|
||||||
|
Environment variable:
|
||||||
|
LLAMA_STACK_LOGGING: Semicolon-separated list of category=level pairs
|
||||||
|
Example: "server=debug;inference=warning"
|
||||||
|
"""
|
||||||
|
|
||||||
|
import datetime
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
from typing import Dict
|
||||||
|
|
||||||
|
# ANSI color codes for terminal output
|
||||||
|
COLORS = {
|
||||||
|
"RESET": "\033[0m",
|
||||||
|
"DEBUG": "\033[36m", # Cyan
|
||||||
|
"INFO": "\033[32m", # Green
|
||||||
|
"WARNING": "\033[33m", # Yellow
|
||||||
|
"ERROR": "\033[31m", # Red
|
||||||
|
"CRITICAL": "\033[35m", # Magenta
|
||||||
|
"DIM": "\033[2m", # Dimmed text
|
||||||
|
"YELLOW_DIM": "\033[2;33m", # Dimmed yellow
|
||||||
|
}
|
||||||
|
|
||||||
|
# Static list of valid categories representing various parts of the Llama Stack
|
||||||
|
# server codebase
|
||||||
|
CATEGORIES = [
|
||||||
|
"core",
|
||||||
|
"server",
|
||||||
|
"router",
|
||||||
|
"inference",
|
||||||
|
"agents",
|
||||||
|
"safety",
|
||||||
|
"eval",
|
||||||
|
"tools",
|
||||||
|
"client",
|
||||||
|
]
|
||||||
|
|
||||||
|
_logger = logging.getLogger("llama_stack")
|
||||||
|
_logger.propagate = False
|
||||||
|
|
||||||
|
_default_level = logging.INFO
|
||||||
|
|
||||||
|
# Category-level mapping (can be modified by environment variables)
|
||||||
|
_category_levels: Dict[str, int] = {}
|
||||||
|
|
||||||
|
|
||||||
|
class TerminalStreamHandler(logging.StreamHandler):
|
||||||
|
def __init__(self, stream=None):
|
||||||
|
super().__init__(stream)
|
||||||
|
self.is_tty = hasattr(self.stream, "isatty") and self.stream.isatty()
|
||||||
|
|
||||||
|
def format(self, record):
|
||||||
|
record.is_tty = self.is_tty
|
||||||
|
return super().format(record)
|
||||||
|
|
||||||
|
|
||||||
|
class ColoredFormatter(logging.Formatter):
|
||||||
|
"""Custom formatter with colors and fixed-width level names"""
|
||||||
|
|
||||||
|
def format(self, record):
|
||||||
|
levelname = record.levelname
|
||||||
|
# Use only time with milliseconds, not date
|
||||||
|
timestamp = datetime.datetime.now().strftime("%H:%M:%S.%f")[:-3] # HH:MM:SS.mmm format
|
||||||
|
|
||||||
|
file_info = f"{record.filename}:{record.lineno}"
|
||||||
|
|
||||||
|
# Get category from extra if available
|
||||||
|
category = getattr(record, "category", None)
|
||||||
|
msg = record.getMessage()
|
||||||
|
|
||||||
|
if getattr(record, "is_tty", False):
|
||||||
|
color = COLORS.get(levelname, COLORS["RESET"])
|
||||||
|
if category:
|
||||||
|
category_formatted = f"{COLORS['YELLOW_DIM']}{category}{COLORS['RESET']} "
|
||||||
|
formatted_msg = (
|
||||||
|
f"{color}{levelname:<7}{COLORS['RESET']} {COLORS['DIM']}{timestamp}{COLORS['RESET']} "
|
||||||
|
f"{file_info:<20} {category_formatted}{msg}"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
formatted_msg = (
|
||||||
|
f"{color}{levelname:<7}{COLORS['RESET']} {COLORS['DIM']}{timestamp}{COLORS['RESET']}] "
|
||||||
|
f"{file_info:<20} {msg}"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
if category:
|
||||||
|
formatted_msg = f"{levelname:<7} {timestamp} {file_info:<20} [{category}] {msg}"
|
||||||
|
else:
|
||||||
|
formatted_msg = f"{levelname:<7} {timestamp} {file_info:<20} {msg}"
|
||||||
|
|
||||||
|
return formatted_msg
|
||||||
|
|
||||||
|
|
||||||
|
def init(default_level: int = logging.INFO) -> None:
|
||||||
|
global _default_level, _category_levels, _logger
|
||||||
|
|
||||||
|
_default_level = default_level
|
||||||
|
|
||||||
|
_logger.setLevel(logging.DEBUG)
|
||||||
|
_logger.handlers = [] # Clear existing handlers
|
||||||
|
|
||||||
|
# Add our custom handler with the colored formatter
|
||||||
|
handler = TerminalStreamHandler()
|
||||||
|
formatter = ColoredFormatter()
|
||||||
|
handler.setFormatter(formatter)
|
||||||
|
_logger.addHandler(handler)
|
||||||
|
|
||||||
|
for category in CATEGORIES:
|
||||||
|
_category_levels[category] = default_level
|
||||||
|
|
||||||
|
env_config = os.environ.get("LLAMA_STACK_LOGGING", "")
|
||||||
|
if env_config:
|
||||||
|
for pair in env_config.split(";"):
|
||||||
|
if not pair.strip():
|
||||||
|
continue
|
||||||
|
|
||||||
|
try:
|
||||||
|
category, level = pair.split("=", 1)
|
||||||
|
category = category.strip().lower()
|
||||||
|
level = level.strip().lower()
|
||||||
|
|
||||||
|
level_value = {
|
||||||
|
"debug": logging.DEBUG,
|
||||||
|
"info": logging.INFO,
|
||||||
|
"warning": logging.WARNING,
|
||||||
|
"warn": logging.WARNING,
|
||||||
|
"error": logging.ERROR,
|
||||||
|
"critical": logging.CRITICAL,
|
||||||
|
}.get(level)
|
||||||
|
|
||||||
|
if level_value is None:
|
||||||
|
_logger.warning(f"Unknown log level '{level}' for category '{category}'")
|
||||||
|
continue
|
||||||
|
|
||||||
|
if category == "all":
|
||||||
|
for cat in CATEGORIES:
|
||||||
|
_category_levels[cat] = level_value
|
||||||
|
else:
|
||||||
|
if category in CATEGORIES:
|
||||||
|
_category_levels[category] = level_value
|
||||||
|
else:
|
||||||
|
_logger.warning(f"Unknown logging category: {category}")
|
||||||
|
|
||||||
|
except ValueError:
|
||||||
|
_logger.warning(f"Invalid logging configuration: {pair}")
|
||||||
|
|
||||||
|
|
||||||
|
def _should_log(level: int, category: str) -> bool:
|
||||||
|
category = category.lower()
|
||||||
|
if category not in _category_levels:
|
||||||
|
return False
|
||||||
|
category_level = _category_levels[category]
|
||||||
|
return level >= category_level
|
||||||
|
|
||||||
|
|
||||||
|
def _log(level: int, level_name: str, category: str, msg: str, *args, **kwargs) -> None:
|
||||||
|
if _should_log(level, category):
|
||||||
|
kwargs.setdefault("extra", {})["category"] = category.lower()
|
||||||
|
getattr(_logger, level_name)(msg, *args, stacklevel=3, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def debug(category: str, msg: str, *args, **kwargs) -> None:
|
||||||
|
_log(logging.DEBUG, "debug", category, msg, *args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def info(category: str, msg: str, *args, **kwargs) -> None:
|
||||||
|
_log(logging.INFO, "info", category, msg, *args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def warning(category: str, msg: str, *args, **kwargs) -> None:
|
||||||
|
_log(logging.WARNING, "warning", category, msg, *args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def warn(category: str, msg: str, *args, **kwargs) -> None:
|
||||||
|
warning(category, msg, *args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def error(category: str, msg: str, *args, **kwargs) -> None:
|
||||||
|
_log(logging.ERROR, "error", category, msg, *args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def critical(category: str, msg: str, *args, **kwargs) -> None:
|
||||||
|
_log(logging.CRITICAL, "critical", category, msg, *args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def exception(category: str, msg: str, *args, **kwargs) -> None:
|
||||||
|
if _should_log(logging.ERROR, category):
|
||||||
|
kwargs.setdefault("extra", {})["category"] = category.lower()
|
||||||
|
_logger.exception(msg, *args, stacklevel=2, **kwargs)
|
|
@ -11,16 +11,128 @@
|
||||||
# top-level folder for each specific model found within the models/ directory at
|
# top-level folder for each specific model found within the models/ directory at
|
||||||
# the top-level of this source tree.
|
# the top-level of this source tree.
|
||||||
|
|
||||||
|
import base64
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Any, Dict, Literal, Optional, Union
|
from io import BytesIO
|
||||||
|
from typing import Any, Dict, List, Literal, Optional, Union
|
||||||
|
|
||||||
# import all for backwards compatibility
|
from pydantic import BaseModel, ConfigDict, Field, field_serializer, field_validator
|
||||||
from llama_models.datatypes import * # noqa: F403
|
|
||||||
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
|
||||||
from typing_extensions import Annotated
|
from typing_extensions import Annotated
|
||||||
|
|
||||||
from llama_stack.schema_utils import json_schema_type, register_schema
|
from llama_stack.schema_utils import json_schema_type, register_schema
|
||||||
|
|
||||||
|
# The goal is that these set of types are relevant for all Llama models.
|
||||||
|
# That isn't the current state yet -- e.g., BuiltinTool is somewhat specific to
|
||||||
|
# the llama3 series of models.
|
||||||
|
|
||||||
|
|
||||||
|
class Role(Enum):
|
||||||
|
system = "system"
|
||||||
|
user = "user"
|
||||||
|
assistant = "assistant"
|
||||||
|
tool = "tool"
|
||||||
|
|
||||||
|
|
||||||
|
class BuiltinTool(Enum):
|
||||||
|
brave_search = "brave_search"
|
||||||
|
wolfram_alpha = "wolfram_alpha"
|
||||||
|
photogen = "photogen"
|
||||||
|
code_interpreter = "code_interpreter"
|
||||||
|
|
||||||
|
|
||||||
|
Primitive = Union[str, int, float, bool, None]
|
||||||
|
RecursiveType = Union[Primitive, List[Primitive], Dict[str, Primitive]]
|
||||||
|
|
||||||
|
|
||||||
|
class ToolCall(BaseModel):
|
||||||
|
call_id: str
|
||||||
|
tool_name: Union[BuiltinTool, str]
|
||||||
|
arguments: Dict[str, RecursiveType]
|
||||||
|
|
||||||
|
@field_validator("tool_name", mode="before")
|
||||||
|
@classmethod
|
||||||
|
def validate_field(cls, v):
|
||||||
|
if isinstance(v, str):
|
||||||
|
try:
|
||||||
|
return BuiltinTool(v)
|
||||||
|
except ValueError:
|
||||||
|
return v
|
||||||
|
return v
|
||||||
|
|
||||||
|
|
||||||
|
class ToolPromptFormat(Enum):
|
||||||
|
"""Prompt format for calling custom / zero shot tools.
|
||||||
|
|
||||||
|
:cvar json: JSON format for calling tools. It takes the form:
|
||||||
|
{
|
||||||
|
"type": "function",
|
||||||
|
"function" : {
|
||||||
|
"name": "function_name",
|
||||||
|
"description": "function_description",
|
||||||
|
"parameters": {...}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
:cvar function_tag: Function tag format, pseudo-XML. This looks like:
|
||||||
|
<function=function_name>(parameters)</function>
|
||||||
|
|
||||||
|
:cvar python_list: Python list. The output is a valid Python expression that can be
|
||||||
|
evaluated to a list. Each element in the list is a function call. Example:
|
||||||
|
["function_name(param1, param2)", "function_name(param1, param2)"]
|
||||||
|
"""
|
||||||
|
|
||||||
|
json = "json"
|
||||||
|
function_tag = "function_tag"
|
||||||
|
python_list = "python_list"
|
||||||
|
|
||||||
|
|
||||||
|
class StopReason(Enum):
|
||||||
|
end_of_turn = "end_of_turn"
|
||||||
|
end_of_message = "end_of_message"
|
||||||
|
out_of_tokens = "out_of_tokens"
|
||||||
|
|
||||||
|
|
||||||
|
class RawMediaItem(BaseModel):
|
||||||
|
type: Literal["image"] = "image"
|
||||||
|
data: bytes | BytesIO
|
||||||
|
|
||||||
|
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||||
|
|
||||||
|
@field_serializer("data")
|
||||||
|
def serialize_data(self, data: Optional[bytes], _info):
|
||||||
|
if data is None:
|
||||||
|
return None
|
||||||
|
return base64.b64encode(data).decode("utf-8")
|
||||||
|
|
||||||
|
@field_validator("data", mode="before")
|
||||||
|
@classmethod
|
||||||
|
def validate_data(cls, v):
|
||||||
|
if isinstance(v, str):
|
||||||
|
return base64.b64decode(v)
|
||||||
|
return v
|
||||||
|
|
||||||
|
|
||||||
|
class RawTextItem(BaseModel):
|
||||||
|
type: Literal["text"] = "text"
|
||||||
|
text: str
|
||||||
|
|
||||||
|
|
||||||
|
RawContentItem = Annotated[Union[RawTextItem, RawMediaItem], Field(discriminator="type")]
|
||||||
|
|
||||||
|
RawContent = str | RawContentItem | List[RawContentItem]
|
||||||
|
|
||||||
|
|
||||||
|
class RawMessage(BaseModel):
|
||||||
|
role: Literal["user"] | Literal["system"] | Literal["tool"] | Literal["assistant"]
|
||||||
|
content: RawContent
|
||||||
|
|
||||||
|
# This is for RAG but likely should be absorbed into content
|
||||||
|
context: Optional[RawContent] = None
|
||||||
|
|
||||||
|
# These are for the output message coming from the assistant
|
||||||
|
stop_reason: Optional[StopReason] = None
|
||||||
|
tool_calls: List[ToolCall] = Field(default_factory=list)
|
||||||
|
|
||||||
|
|
||||||
register_schema(ToolCall)
|
register_schema(ToolCall)
|
||||||
|
|
||||||
|
|
||||||
|
|
10
llama_stack/scripts/install_packages.sh → llama_stack/models/llama/llama3/__init__.py
Executable file → Normal file
10
llama_stack/scripts/install_packages.sh → llama_stack/models/llama/llama3/__init__.py
Executable file → Normal file
|
@ -1,15 +1,5 @@
|
||||||
#!/bin/bash
|
|
||||||
|
|
||||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
# All rights reserved.
|
# All rights reserved.
|
||||||
#
|
#
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
VERSION="$1"
|
|
||||||
|
|
||||||
set -euo pipefail
|
|
||||||
set -x
|
|
||||||
|
|
||||||
uv pip install -U --extra-index-url https://test.pypi.org/simple \
|
|
||||||
llama-stack==$VERSION llama-models==$VERSION llama-stack-client==$VERSION
|
|
82
llama_stack/models/llama/llama3/args.py
Normal file
82
llama_stack/models/llama/llama3/args.py
Normal file
|
@ -0,0 +1,82 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# top-level folder for each specific model found within the models/ directory at
|
||||||
|
# the top-level of this source tree.
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from enum import Enum
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
|
||||||
|
class QuantizationScheme(Enum):
|
||||||
|
int4_weight_int8_dynamic_activation = "int4_weight_int8_dynamic_activation"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class QuantizationArgs:
|
||||||
|
scheme: Optional[QuantizationScheme] = None
|
||||||
|
group_size: Optional[int] = None
|
||||||
|
spinquant: bool = False
|
||||||
|
|
||||||
|
def __init__(self, **kwargs):
|
||||||
|
for k, v in kwargs.items():
|
||||||
|
if k == "scheme":
|
||||||
|
setattr(self, k, QuantizationScheme(v))
|
||||||
|
else:
|
||||||
|
if hasattr(self, k):
|
||||||
|
setattr(self, k, v)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class LoRAArgs:
|
||||||
|
rank: int
|
||||||
|
scale: float
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ModelArgs:
|
||||||
|
dim: int = 4096
|
||||||
|
n_layers: int = 32
|
||||||
|
n_heads: int = 32
|
||||||
|
n_kv_heads: Optional[int] = None
|
||||||
|
vocab_size: int = -1
|
||||||
|
multiple_of: int = 256 # make SwiGLU hidden layer size multiple of large power of 2
|
||||||
|
ffn_dim_multiplier: Optional[float] = None
|
||||||
|
norm_eps: float = 1e-5
|
||||||
|
rope_theta: float = 500000
|
||||||
|
use_scaled_rope: bool = False
|
||||||
|
|
||||||
|
max_batch_size: int = 32
|
||||||
|
max_seq_len: int = 2048
|
||||||
|
|
||||||
|
# vision model params
|
||||||
|
vision_chunk_size: int = -1 # image resolution for image models
|
||||||
|
vision_max_num_chunks: int = 4
|
||||||
|
vision_num_cross_attention_layers: int = -1
|
||||||
|
|
||||||
|
quantization_args: Optional[QuantizationArgs] = None
|
||||||
|
lora_args: Optional[LoRAArgs] = None
|
||||||
|
|
||||||
|
def __init__(self, **kwargs):
|
||||||
|
for k, v in kwargs.items():
|
||||||
|
if k == "lora_args":
|
||||||
|
setattr(self, k, LoRAArgs(**v))
|
||||||
|
elif k == "quantization_args":
|
||||||
|
setattr(self, k, QuantizationArgs(**v))
|
||||||
|
else:
|
||||||
|
if hasattr(self, k):
|
||||||
|
setattr(self, k, v)
|
||||||
|
|
||||||
|
if self.n_kv_heads is None:
|
||||||
|
self.n_kv_heads = self.n_heads
|
||||||
|
assert self.n_kv_heads <= self.n_heads
|
||||||
|
assert self.n_heads % self.n_kv_heads == 0
|
||||||
|
assert self.dim % self.n_heads == 0
|
282
llama_stack/models/llama/llama3/chat_format.py
Normal file
282
llama_stack/models/llama/llama3/chat_format.py
Normal file
|
@ -0,0 +1,282 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# top-level folder for each specific model found within the models/ directory at
|
||||||
|
# the top-level of this source tree.
|
||||||
|
|
||||||
|
import io
|
||||||
|
import uuid
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Dict, List, Optional, Tuple
|
||||||
|
|
||||||
|
from PIL import Image as PIL_Image
|
||||||
|
|
||||||
|
from llama_stack.models.llama.datatypes import (
|
||||||
|
BuiltinTool,
|
||||||
|
RawContent,
|
||||||
|
RawMediaItem,
|
||||||
|
RawMessage,
|
||||||
|
RawTextItem,
|
||||||
|
Role,
|
||||||
|
StopReason,
|
||||||
|
ToolCall,
|
||||||
|
ToolPromptFormat,
|
||||||
|
)
|
||||||
|
|
||||||
|
from .tokenizer import Tokenizer
|
||||||
|
from .tool_utils import ToolUtils
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class VisionInput:
|
||||||
|
mask: List[List[int]]
|
||||||
|
images: List[PIL_Image.Image]
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class LLMInput:
|
||||||
|
tokens: List[int]
|
||||||
|
vision: Optional[VisionInput] = None
|
||||||
|
|
||||||
|
|
||||||
|
def role_str(role: Role) -> str:
|
||||||
|
role_strs = {
|
||||||
|
Role.user: "user",
|
||||||
|
Role.system: "system",
|
||||||
|
Role.tool: "ipython", # special
|
||||||
|
Role.assistant: "assistant",
|
||||||
|
}
|
||||||
|
return role_strs[role]
|
||||||
|
|
||||||
|
|
||||||
|
class ChatFormat:
|
||||||
|
possible_headers: Dict[Role, str]
|
||||||
|
|
||||||
|
def __init__(self, tokenizer: Tokenizer):
|
||||||
|
self.tokenizer = tokenizer
|
||||||
|
|
||||||
|
self.possible_headers = {role: f"<|start_header_id|>{role_str(role)}<|end_header_id|>\n\n" for role in Role}
|
||||||
|
self.vision_token = self.tokenizer.special_tokens["<|image|>"]
|
||||||
|
|
||||||
|
def _encode_header(self, role: str) -> List[int]:
|
||||||
|
tokens = []
|
||||||
|
tokens.append(self.tokenizer.special_tokens["<|start_header_id|>"])
|
||||||
|
tokens.extend(self.tokenizer.encode("ipython" if role == "tool" else role, bos=False, eos=False))
|
||||||
|
tokens.append(self.tokenizer.special_tokens["<|end_header_id|>"])
|
||||||
|
tokens.extend(self.tokenizer.encode("\n\n", bos=False, eos=False))
|
||||||
|
return tokens
|
||||||
|
|
||||||
|
def encode_content(self, content: RawContent) -> LLMInput:
|
||||||
|
tokens, images = self._encode_content(content, bos=True)
|
||||||
|
return self._model_input_from_tokens_images(tokens, images)
|
||||||
|
|
||||||
|
def _encode_content(self, content: RawContent, bos: bool = False) -> Tuple[List[int], List[PIL_Image.Image]]:
|
||||||
|
tokens = []
|
||||||
|
images = []
|
||||||
|
|
||||||
|
added_bos = False
|
||||||
|
|
||||||
|
def _process(c):
|
||||||
|
nonlocal added_bos, bos
|
||||||
|
|
||||||
|
if isinstance(c, str) or isinstance(c, RawTextItem):
|
||||||
|
if isinstance(c, RawTextItem):
|
||||||
|
c = c.text
|
||||||
|
tokens.extend(self.tokenizer.encode(c, bos=False if added_bos else bos, eos=False))
|
||||||
|
added_bos = True
|
||||||
|
|
||||||
|
elif isinstance(c, RawMediaItem):
|
||||||
|
bos = False if added_bos else bos
|
||||||
|
if bos:
|
||||||
|
tokens.append(self.tokenizer.special_tokens["<|begin_of_text|>"])
|
||||||
|
added_bos = True
|
||||||
|
tokens.append(self.vision_token)
|
||||||
|
|
||||||
|
bytes_io = io.BytesIO(c.data) if isinstance(c.data, bytes) else c.data
|
||||||
|
image = PIL_Image.open(bytes_io)
|
||||||
|
image = image.convert("RGB")
|
||||||
|
images.append(image)
|
||||||
|
|
||||||
|
if isinstance(content, list):
|
||||||
|
for c in content:
|
||||||
|
_process(c)
|
||||||
|
else:
|
||||||
|
_process(content)
|
||||||
|
|
||||||
|
return tokens, images
|
||||||
|
|
||||||
|
def encode_message(
|
||||||
|
self, message: RawMessage, tool_prompt_format: ToolPromptFormat
|
||||||
|
) -> Tuple[List[int], List[PIL_Image.Image]]:
|
||||||
|
tokens = self._encode_header(message.role)
|
||||||
|
images = []
|
||||||
|
|
||||||
|
def _process_content(c):
|
||||||
|
toks, imgs = self._encode_content(c)
|
||||||
|
tokens.extend(toks)
|
||||||
|
images.extend(imgs)
|
||||||
|
|
||||||
|
if (
|
||||||
|
message.role == "assistant"
|
||||||
|
and len(message.tool_calls) > 0
|
||||||
|
and message.tool_calls[0].tool_name == BuiltinTool.code_interpreter
|
||||||
|
):
|
||||||
|
tokens.append(self.tokenizer.special_tokens["<|python_tag|>"])
|
||||||
|
|
||||||
|
_process_content(message.content)
|
||||||
|
|
||||||
|
if message.role == "user" and message.context is not None:
|
||||||
|
# This is RAG context; why is it here in the chat format? I don't think
|
||||||
|
# this is needed and can be moved upwards
|
||||||
|
_process_content("\n\n")
|
||||||
|
_process_content(message.context)
|
||||||
|
|
||||||
|
if message.role == "assistant":
|
||||||
|
for t in message.tool_calls:
|
||||||
|
content = ToolUtils.encode_tool_call(t, tool_prompt_format)
|
||||||
|
_process_content(content)
|
||||||
|
|
||||||
|
eom = False
|
||||||
|
if message.role == "assistant":
|
||||||
|
eom = message.stop_reason == StopReason.end_of_message
|
||||||
|
|
||||||
|
tokens.append(self.tokenizer.special_tokens["<|eom_id|>" if eom else "<|eot_id|>"])
|
||||||
|
return tokens, images
|
||||||
|
|
||||||
|
def encode_dialog_prompt(
|
||||||
|
self,
|
||||||
|
messages: List[RawMessage],
|
||||||
|
tool_prompt_format: Optional[ToolPromptFormat] = None,
|
||||||
|
) -> LLMInput:
|
||||||
|
tool_prompt_format = tool_prompt_format or ToolPromptFormat.json
|
||||||
|
tokens = []
|
||||||
|
images = []
|
||||||
|
tokens.append(self.tokenizer.special_tokens["<|begin_of_text|>"])
|
||||||
|
for message in messages:
|
||||||
|
toks, imgs = self.encode_message(message, tool_prompt_format)
|
||||||
|
tokens.extend(toks)
|
||||||
|
images.extend(imgs)
|
||||||
|
|
||||||
|
# Add the start of an assistant message for the model to complete.
|
||||||
|
tokens.extend(self._encode_header("assistant"))
|
||||||
|
|
||||||
|
return self._model_input_from_tokens_images(tokens, images)
|
||||||
|
|
||||||
|
# TODO(this should be generic, not only for assistant messages)
|
||||||
|
def decode_assistant_message(self, tokens: List[int], stop_reason: StopReason) -> RawMessage:
|
||||||
|
content = self.tokenizer.decode(tokens)
|
||||||
|
|
||||||
|
return self.decode_assistant_message_from_content(content, stop_reason)
|
||||||
|
|
||||||
|
def decode_assistant_message_from_content(self, content: str, stop_reason: StopReason) -> RawMessage:
|
||||||
|
content = content.strip(" ")
|
||||||
|
header_str = self.possible_headers[Role.assistant]
|
||||||
|
if content.startswith(header_str):
|
||||||
|
content = content[len(header_str) :]
|
||||||
|
|
||||||
|
ipython = content.startswith("<|python_tag|>")
|
||||||
|
if ipython:
|
||||||
|
content = content[len("<|python_tag|>") :]
|
||||||
|
|
||||||
|
if content.endswith("<|eot_id|>"):
|
||||||
|
content = content[: -len("<|eot_id|>")]
|
||||||
|
stop_reason = StopReason.end_of_turn
|
||||||
|
elif content.endswith("<|eom_id|>"):
|
||||||
|
content = content[: -len("<|eom_id|>")]
|
||||||
|
stop_reason = StopReason.end_of_message
|
||||||
|
|
||||||
|
tool_name = None
|
||||||
|
tool_arguments = {}
|
||||||
|
|
||||||
|
custom_tool_info = ToolUtils.maybe_extract_custom_tool_call(content)
|
||||||
|
if custom_tool_info is not None:
|
||||||
|
tool_name, tool_arguments = custom_tool_info
|
||||||
|
# Sometimes when agent has custom tools alongside builin tools
|
||||||
|
# Agent responds for builtin tool calls in the format of the custom tools
|
||||||
|
# This code tries to handle that case
|
||||||
|
if tool_name in BuiltinTool.__members__:
|
||||||
|
tool_name = BuiltinTool[tool_name]
|
||||||
|
tool_arguments = {
|
||||||
|
"query": list(tool_arguments.values())[0],
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
builtin_tool_info = ToolUtils.maybe_extract_builtin_tool_call(content)
|
||||||
|
if builtin_tool_info is not None:
|
||||||
|
tool_name, query = builtin_tool_info
|
||||||
|
tool_arguments = {
|
||||||
|
"query": query,
|
||||||
|
}
|
||||||
|
if tool_name in BuiltinTool.__members__:
|
||||||
|
tool_name = BuiltinTool[tool_name]
|
||||||
|
elif ipython:
|
||||||
|
tool_name = BuiltinTool.code_interpreter
|
||||||
|
tool_arguments = {
|
||||||
|
"code": content,
|
||||||
|
}
|
||||||
|
|
||||||
|
tool_calls = []
|
||||||
|
if tool_name is not None and tool_arguments is not None:
|
||||||
|
call_id = str(uuid.uuid4())
|
||||||
|
tool_calls.append(
|
||||||
|
ToolCall(
|
||||||
|
call_id=call_id,
|
||||||
|
tool_name=tool_name,
|
||||||
|
arguments=tool_arguments,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
content = ""
|
||||||
|
|
||||||
|
return RawMessage(
|
||||||
|
role="assistant",
|
||||||
|
content=content,
|
||||||
|
stop_reason=stop_reason,
|
||||||
|
tool_calls=tool_calls,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _model_input_from_tokens_images(self, tokens: List[int], images: List[PIL_Image.Image]) -> LLMInput:
|
||||||
|
vision_input = None
|
||||||
|
if len(images) > 0:
|
||||||
|
vision_input = VisionInput(
|
||||||
|
mask=create_vision_mask(tokens, self.vision_token),
|
||||||
|
images=images,
|
||||||
|
)
|
||||||
|
|
||||||
|
return LLMInput(
|
||||||
|
tokens=[128256 if token == self.vision_token else token for token in tokens],
|
||||||
|
vision=vision_input,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def create_vision_mask(
|
||||||
|
tokens: List[int],
|
||||||
|
vision_token: int,
|
||||||
|
) -> List[List[int]]:
|
||||||
|
vision_token_locations = [i for i, token in enumerate(tokens) if token == vision_token]
|
||||||
|
if len(vision_token_locations) == 0:
|
||||||
|
return []
|
||||||
|
|
||||||
|
if len(vision_token_locations) == 1:
|
||||||
|
# only one image present, unmask until end of sequence
|
||||||
|
return [[vision_token_locations[0], -1]]
|
||||||
|
vision_masks = [
|
||||||
|
[loc1, loc2] for loc1, loc2 in zip(vision_token_locations[:-1], vision_token_locations[1:], strict=False)
|
||||||
|
]
|
||||||
|
# last image will attend to all subsequent text
|
||||||
|
vision_masks.append([vision_token_locations[-1], len(tokens)])
|
||||||
|
|
||||||
|
# if there are two or more consecutive vision tokens,
|
||||||
|
# they should all attend to all subsequent
|
||||||
|
# text present
|
||||||
|
last_mask_end = vision_masks[-1][1]
|
||||||
|
for vision_mask in vision_masks[::-1]:
|
||||||
|
if vision_mask[0] == vision_mask[1] - 1:
|
||||||
|
vision_mask[1] = last_mask_end
|
||||||
|
last_mask_end = vision_mask[1]
|
||||||
|
return vision_masks
|
|
@ -14,20 +14,19 @@
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
|
|
||||||
from llama_models.datatypes import (
|
from termcolor import colored
|
||||||
|
|
||||||
|
from llama_stack.models.llama.datatypes import (
|
||||||
BuiltinTool,
|
BuiltinTool,
|
||||||
RawMessage,
|
RawMessage,
|
||||||
StopReason,
|
StopReason,
|
||||||
ToolCall,
|
ToolCall,
|
||||||
|
ToolDefinition,
|
||||||
ToolPromptFormat,
|
ToolPromptFormat,
|
||||||
)
|
)
|
||||||
from llama_models.llama3.api.chat_format import ChatFormat
|
|
||||||
from llama_models.llama3.api.tokenizer import Tokenizer
|
|
||||||
from termcolor import colored
|
|
||||||
|
|
||||||
from llama_stack.models.llama.datatypes import ToolDefinition
|
|
||||||
|
|
||||||
from . import template_data
|
from . import template_data
|
||||||
|
from .chat_format import ChatFormat
|
||||||
from .prompt_templates import (
|
from .prompt_templates import (
|
||||||
BuiltinToolGenerator,
|
BuiltinToolGenerator,
|
||||||
FunctionTagCustomToolGenerator,
|
FunctionTagCustomToolGenerator,
|
||||||
|
@ -35,6 +34,7 @@ from .prompt_templates import (
|
||||||
SystemDefaultGenerator,
|
SystemDefaultGenerator,
|
||||||
ToolResponseGenerator,
|
ToolResponseGenerator,
|
||||||
)
|
)
|
||||||
|
from .tokenizer import Tokenizer
|
||||||
|
|
||||||
THIS_DIR = Path(__file__).parent
|
THIS_DIR = Path(__file__).parent
|
||||||
|
|
||||||
|
|
311
llama_stack/models/llama/llama3/model.py
Normal file
311
llama_stack/models/llama/llama3/model.py
Normal file
|
@ -0,0 +1,311 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# top-level folder for each specific model found within the models/ directory at
|
||||||
|
# the top-level of this source tree.
|
||||||
|
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
|
||||||
|
|
||||||
|
import math
|
||||||
|
from typing import Optional, Tuple
|
||||||
|
|
||||||
|
import fairscale.nn.model_parallel.initialize as fs_init
|
||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from fairscale.nn.model_parallel.layers import (
|
||||||
|
ColumnParallelLinear,
|
||||||
|
RowParallelLinear,
|
||||||
|
VocabParallelEmbedding,
|
||||||
|
)
|
||||||
|
from torch import nn
|
||||||
|
|
||||||
|
from .args import ModelArgs
|
||||||
|
|
||||||
|
|
||||||
|
class RMSNorm(torch.nn.Module):
|
||||||
|
def __init__(self, dim: int, eps: float = 1e-6):
|
||||||
|
super().__init__()
|
||||||
|
self.eps = eps
|
||||||
|
self.weight = nn.Parameter(torch.ones(dim))
|
||||||
|
|
||||||
|
def _norm(self, x):
|
||||||
|
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
output = self._norm(x.float()).type_as(x)
|
||||||
|
return output * self.weight
|
||||||
|
|
||||||
|
|
||||||
|
def apply_scaling(freqs: torch.Tensor) -> torch.Tensor:
|
||||||
|
# Values obtained from grid search
|
||||||
|
scale_factor = 8
|
||||||
|
low_freq_factor = 1
|
||||||
|
high_freq_factor = 4
|
||||||
|
old_context_len = 8192 # original llama3 length
|
||||||
|
|
||||||
|
low_freq_wavelen = old_context_len / low_freq_factor
|
||||||
|
high_freq_wavelen = old_context_len / high_freq_factor
|
||||||
|
|
||||||
|
wavelen = 2 * torch.pi / freqs
|
||||||
|
new_freqs = torch.where(wavelen > low_freq_wavelen, freqs / scale_factor, freqs)
|
||||||
|
smooth = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor)
|
||||||
|
return torch.where(
|
||||||
|
(wavelen >= high_freq_wavelen) & (wavelen <= low_freq_wavelen),
|
||||||
|
(1 - smooth) * new_freqs / scale_factor + smooth * new_freqs,
|
||||||
|
new_freqs,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0, use_scaled: bool = False):
|
||||||
|
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
|
||||||
|
t = torch.arange(end, device=freqs.device, dtype=torch.float32)
|
||||||
|
if use_scaled:
|
||||||
|
freqs = apply_scaling(freqs)
|
||||||
|
freqs = torch.outer(t, freqs)
|
||||||
|
freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64
|
||||||
|
return freqs_cis
|
||||||
|
|
||||||
|
|
||||||
|
def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
|
||||||
|
ndim = x.ndim
|
||||||
|
assert 0 <= 1 < ndim
|
||||||
|
assert freqs_cis.shape == (x.shape[1], x.shape[-1])
|
||||||
|
shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
|
||||||
|
return freqs_cis.view(*shape)
|
||||||
|
|
||||||
|
|
||||||
|
def apply_rotary_emb(
|
||||||
|
xq: torch.Tensor,
|
||||||
|
xk: torch.Tensor,
|
||||||
|
freqs_cis: torch.Tensor,
|
||||||
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
|
||||||
|
xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
|
||||||
|
freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
|
||||||
|
xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
|
||||||
|
xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
|
||||||
|
return xq_out.type_as(xq), xk_out.type_as(xk)
|
||||||
|
|
||||||
|
|
||||||
|
def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
|
||||||
|
"""torch.repeat_interleave(x, dim=2, repeats=n_rep)"""
|
||||||
|
bs, slen, n_kv_heads, head_dim = x.shape
|
||||||
|
if n_rep == 1:
|
||||||
|
return x
|
||||||
|
return (
|
||||||
|
x[:, :, :, None, :]
|
||||||
|
.expand(bs, slen, n_kv_heads, n_rep, head_dim)
|
||||||
|
.reshape(bs, slen, n_kv_heads * n_rep, head_dim)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class Attention(nn.Module):
|
||||||
|
def __init__(self, args: ModelArgs):
|
||||||
|
super().__init__()
|
||||||
|
self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
|
||||||
|
model_parallel_size = fs_init.get_model_parallel_world_size()
|
||||||
|
self.n_local_heads = args.n_heads // model_parallel_size
|
||||||
|
self.n_local_kv_heads = self.n_kv_heads // model_parallel_size
|
||||||
|
self.n_rep = self.n_local_heads // self.n_local_kv_heads
|
||||||
|
self.head_dim = args.dim // args.n_heads
|
||||||
|
|
||||||
|
self.wq = ColumnParallelLinear(
|
||||||
|
args.dim,
|
||||||
|
args.n_heads * self.head_dim,
|
||||||
|
bias=False,
|
||||||
|
gather_output=False,
|
||||||
|
init_method=lambda x: x,
|
||||||
|
)
|
||||||
|
self.wk = ColumnParallelLinear(
|
||||||
|
args.dim,
|
||||||
|
self.n_kv_heads * self.head_dim,
|
||||||
|
bias=False,
|
||||||
|
gather_output=False,
|
||||||
|
init_method=lambda x: x,
|
||||||
|
)
|
||||||
|
self.wv = ColumnParallelLinear(
|
||||||
|
args.dim,
|
||||||
|
self.n_kv_heads * self.head_dim,
|
||||||
|
bias=False,
|
||||||
|
gather_output=False,
|
||||||
|
init_method=lambda x: x,
|
||||||
|
)
|
||||||
|
self.wo = RowParallelLinear(
|
||||||
|
args.n_heads * self.head_dim,
|
||||||
|
args.dim,
|
||||||
|
bias=False,
|
||||||
|
input_is_parallel=True,
|
||||||
|
init_method=lambda x: x,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.cache_k = torch.zeros(
|
||||||
|
(
|
||||||
|
args.max_batch_size,
|
||||||
|
args.max_seq_len,
|
||||||
|
self.n_local_kv_heads,
|
||||||
|
self.head_dim,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
self.cache_v = torch.zeros(
|
||||||
|
(
|
||||||
|
args.max_batch_size,
|
||||||
|
args.max_seq_len,
|
||||||
|
self.n_local_kv_heads,
|
||||||
|
self.head_dim,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
x: torch.Tensor,
|
||||||
|
start_pos: int,
|
||||||
|
freqs_cis: torch.Tensor,
|
||||||
|
mask: Optional[torch.Tensor],
|
||||||
|
):
|
||||||
|
bsz, seqlen, _ = x.shape
|
||||||
|
xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
|
||||||
|
|
||||||
|
xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
|
||||||
|
xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
|
||||||
|
xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
|
||||||
|
|
||||||
|
xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)
|
||||||
|
|
||||||
|
self.cache_k = self.cache_k.to(xq)
|
||||||
|
self.cache_v = self.cache_v.to(xq)
|
||||||
|
|
||||||
|
self.cache_k[:bsz, start_pos : start_pos + seqlen] = xk
|
||||||
|
self.cache_v[:bsz, start_pos : start_pos + seqlen] = xv
|
||||||
|
|
||||||
|
keys = self.cache_k[:bsz, : start_pos + seqlen]
|
||||||
|
values = self.cache_v[:bsz, : start_pos + seqlen]
|
||||||
|
|
||||||
|
# repeat k/v heads if n_kv_heads < n_heads
|
||||||
|
keys = repeat_kv(keys, self.n_rep) # (bs, cache_len + seqlen, n_local_heads, head_dim)
|
||||||
|
values = repeat_kv(values, self.n_rep) # (bs, cache_len + seqlen, n_local_heads, head_dim)
|
||||||
|
|
||||||
|
xq = xq.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim)
|
||||||
|
keys = keys.transpose(1, 2) # (bs, n_local_heads, cache_len + seqlen, head_dim)
|
||||||
|
values = values.transpose(1, 2) # (bs, n_local_heads, cache_len + seqlen, head_dim)
|
||||||
|
scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim)
|
||||||
|
if mask is not None:
|
||||||
|
scores = scores + mask # (bs, n_local_heads, seqlen, cache_len + seqlen)
|
||||||
|
scores = F.softmax(scores.float(), dim=-1).type_as(xq)
|
||||||
|
output = torch.matmul(scores, values) # (bs, n_local_heads, seqlen, head_dim)
|
||||||
|
output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
|
||||||
|
return self.wo(output)
|
||||||
|
|
||||||
|
|
||||||
|
class FeedForward(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
dim: int,
|
||||||
|
hidden_dim: int,
|
||||||
|
multiple_of: int,
|
||||||
|
ffn_dim_multiplier: Optional[float],
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
hidden_dim = int(2 * hidden_dim / 3)
|
||||||
|
# custom dim factor multiplier
|
||||||
|
if ffn_dim_multiplier is not None:
|
||||||
|
hidden_dim = int(ffn_dim_multiplier * hidden_dim)
|
||||||
|
hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
|
||||||
|
|
||||||
|
self.w1 = ColumnParallelLinear(dim, hidden_dim, bias=False, gather_output=False, init_method=lambda x: x)
|
||||||
|
self.w2 = RowParallelLinear(hidden_dim, dim, bias=False, input_is_parallel=True, init_method=lambda x: x)
|
||||||
|
self.w3 = ColumnParallelLinear(dim, hidden_dim, bias=False, gather_output=False, init_method=lambda x: x)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return self.w2(F.silu(self.w1(x)) * self.w3(x))
|
||||||
|
|
||||||
|
|
||||||
|
class TransformerBlock(nn.Module):
|
||||||
|
def __init__(self, layer_id: int, args: ModelArgs):
|
||||||
|
super().__init__()
|
||||||
|
self.n_heads = args.n_heads
|
||||||
|
self.dim = args.dim
|
||||||
|
self.head_dim = args.dim // args.n_heads
|
||||||
|
self.attention = Attention(args)
|
||||||
|
self.feed_forward = FeedForward(
|
||||||
|
dim=args.dim,
|
||||||
|
hidden_dim=4 * args.dim,
|
||||||
|
multiple_of=args.multiple_of,
|
||||||
|
ffn_dim_multiplier=args.ffn_dim_multiplier,
|
||||||
|
)
|
||||||
|
self.layer_id = layer_id
|
||||||
|
self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
|
||||||
|
self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
x: torch.Tensor,
|
||||||
|
start_pos: int,
|
||||||
|
freqs_cis: torch.Tensor,
|
||||||
|
mask: Optional[torch.Tensor],
|
||||||
|
):
|
||||||
|
h = x + self.attention(self.attention_norm(x), start_pos, freqs_cis, mask)
|
||||||
|
out = h + self.feed_forward(self.ffn_norm(h))
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
class Transformer(nn.Module):
|
||||||
|
def __init__(self, params: ModelArgs):
|
||||||
|
super().__init__()
|
||||||
|
self.params = params
|
||||||
|
self.vocab_size = params.vocab_size
|
||||||
|
self.n_layers = params.n_layers
|
||||||
|
|
||||||
|
self.tok_embeddings = VocabParallelEmbedding(params.vocab_size, params.dim, init_method=lambda x: x)
|
||||||
|
|
||||||
|
self.layers = torch.nn.ModuleList()
|
||||||
|
for layer_id in range(params.n_layers):
|
||||||
|
self.layers.append(TransformerBlock(layer_id, params))
|
||||||
|
|
||||||
|
self.norm = RMSNorm(params.dim, eps=params.norm_eps)
|
||||||
|
self.output = ColumnParallelLinear(params.dim, params.vocab_size, bias=False, init_method=lambda x: x)
|
||||||
|
|
||||||
|
self.freqs_cis = precompute_freqs_cis(
|
||||||
|
params.dim // params.n_heads,
|
||||||
|
params.max_seq_len * 2,
|
||||||
|
params.rope_theta,
|
||||||
|
params.use_scaled_rope,
|
||||||
|
)
|
||||||
|
|
||||||
|
@torch.inference_mode()
|
||||||
|
def forward(self, tokens: torch.Tensor, start_pos: int):
|
||||||
|
_bsz, seqlen = tokens.shape
|
||||||
|
h = self.tok_embeddings(tokens)
|
||||||
|
self.freqs_cis = self.freqs_cis.to(h.device)
|
||||||
|
freqs_cis = self.freqs_cis[start_pos : start_pos + seqlen]
|
||||||
|
|
||||||
|
mask = None
|
||||||
|
if seqlen > 1:
|
||||||
|
mask = torch.full((seqlen, seqlen), float("-inf"), device=tokens.device)
|
||||||
|
|
||||||
|
mask = torch.triu(mask, diagonal=1)
|
||||||
|
|
||||||
|
# https://github.com/pytorch/pytorch/issues/100005
|
||||||
|
# torch.triu is buggy when the device is mps: filled values are
|
||||||
|
# nan instead of 0.
|
||||||
|
if mask.device.type == torch.device("mps").type:
|
||||||
|
mask = torch.nan_to_num(mask, nan=0.0)
|
||||||
|
|
||||||
|
# When performing key-value caching, we compute the attention scores
|
||||||
|
# only for the new sequence. Thus, the matrix of scores is of size
|
||||||
|
# (seqlen, cache_len + seqlen), and the only masked entries are (i, j) for
|
||||||
|
# j > cache_len + i, since row i corresponds to token cache_len + i.
|
||||||
|
mask = torch.hstack([torch.zeros((seqlen, start_pos), device=tokens.device), mask]).type_as(h)
|
||||||
|
|
||||||
|
for layer in self.layers:
|
||||||
|
h = layer(h, start_pos, freqs_cis, mask)
|
||||||
|
h = self.norm(h)
|
||||||
|
output = self.output(h).float()
|
||||||
|
return output
|
12
llama_stack/models/llama/llama3/multimodal/__init__.py
Normal file
12
llama_stack/models/llama/llama3/multimodal/__init__.py
Normal file
|
@ -0,0 +1,12 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# top-level folder for each specific model found within the models/ directory at
|
||||||
|
# the top-level of this source tree.
|
179
llama_stack/models/llama/llama3/multimodal/encoder_utils.py
Normal file
179
llama_stack/models/llama/llama3/multimodal/encoder_utils.py
Normal file
|
@ -0,0 +1,179 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# top-level folder for each specific model found within the models/ directory at
|
||||||
|
# the top-level of this source tree.
|
||||||
|
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and its affiliates.
|
||||||
|
import math
|
||||||
|
from logging import getLogger
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
from .utils import get_negative_inf_value, to_2tuple
|
||||||
|
|
||||||
|
logger = getLogger()
|
||||||
|
|
||||||
|
|
||||||
|
def resize_local_position_embedding(orig_pos_embed, grid_size):
|
||||||
|
"""
|
||||||
|
Resize position embedding for vision encoder.
|
||||||
|
Original position embedding is [n_tiles * n_tiles + 1, dim]
|
||||||
|
New position embedding will be [grid_size[0] * grid_size[1] + 1, dim]
|
||||||
|
"""
|
||||||
|
new_grid_size = to_2tuple(grid_size)
|
||||||
|
orig_grid_size = to_2tuple(int(math.sqrt(len(orig_pos_embed) - 1)))
|
||||||
|
|
||||||
|
new_pos_emb_tok, new_pos_emb_img = (
|
||||||
|
orig_pos_embed[:1],
|
||||||
|
orig_pos_embed[1:],
|
||||||
|
)
|
||||||
|
logger.info(f"resizing position embedding grid-size from {orig_grid_size} to {new_grid_size}")
|
||||||
|
|
||||||
|
new_pos_emb_img = new_pos_emb_img.reshape(1, orig_grid_size[0], orig_grid_size[1], -1).permute(0, 3, 1, 2)
|
||||||
|
|
||||||
|
new_pos_emb_img = F.interpolate(
|
||||||
|
new_pos_emb_img,
|
||||||
|
size=new_grid_size,
|
||||||
|
mode="bilinear",
|
||||||
|
align_corners=True,
|
||||||
|
)
|
||||||
|
new_pos_emb_img = new_pos_emb_img.permute(0, 2, 3, 1).reshape(1, new_grid_size[0] * new_grid_size[1], -1)[0]
|
||||||
|
new_pos_embed = torch.cat([new_pos_emb_tok, new_pos_emb_img], dim=0)
|
||||||
|
return new_pos_embed
|
||||||
|
|
||||||
|
|
||||||
|
def initialize_global_position_embedding_from_local(pos_and_cls_embed, grid_size, x_scale, y_scale):
|
||||||
|
"""
|
||||||
|
Takes a local position embedding for vision encoder and uses it
|
||||||
|
to initialize the global position embedding.
|
||||||
|
Input: local position embedding of shape [grid_size[0] * grid_size[1] + 1, dim]
|
||||||
|
Returns: global position embedding of shape [x_scale, y_scale, grid_size[0] * grid_size[1] + 1, dim]
|
||||||
|
Here x_scale and y_scale are the number of tiles along x-axis and y-axis respectively.
|
||||||
|
"""
|
||||||
|
pos_embed = pos_and_cls_embed[1:]
|
||||||
|
cls_embed = pos_and_cls_embed[0].view(1, 1, 1, -1)
|
||||||
|
grid_size = to_2tuple(grid_size)
|
||||||
|
new_pos_emb_img = pos_embed.reshape(1, grid_size[0], grid_size[1], -1).permute(0, 3, 1, 2)
|
||||||
|
new_grid_size = (x_scale * grid_size[0], y_scale * grid_size[1])
|
||||||
|
new_pos_emb_img = F.interpolate(
|
||||||
|
new_pos_emb_img,
|
||||||
|
size=new_grid_size,
|
||||||
|
mode="bilinear",
|
||||||
|
align_corners=True,
|
||||||
|
)
|
||||||
|
new_pos_emb_img = new_pos_emb_img.permute(0, 2, 3, 1)
|
||||||
|
new_pos_emb_img = new_pos_emb_img.view(x_scale, grid_size[0], y_scale, grid_size[1], -1)
|
||||||
|
new_pos_emb_img = new_pos_emb_img.permute(0, 2, 1, 3, 4).contiguous()
|
||||||
|
new_pos_emb_img = new_pos_emb_img.reshape(x_scale, y_scale, grid_size[0] * grid_size[1], -1)
|
||||||
|
cls_embed = cls_embed.expand(x_scale, y_scale, -1, -1)
|
||||||
|
pos_and_cls_embed = torch.cat([cls_embed, new_pos_emb_img], dim=2)
|
||||||
|
return pos_and_cls_embed
|
||||||
|
|
||||||
|
|
||||||
|
def resize_global_position_embedding(pos_and_cls_embed, grid_size, x_scale, y_scale):
|
||||||
|
"""
|
||||||
|
Takes a global position embedding for vision encoder and resizes it to new size.
|
||||||
|
Input: global position embedding of shape [x_old, y_old, old_grid_size[0] * old_grid_size[1] + 1, dim]
|
||||||
|
Returns: global position embedding of shape [x_scale, y_scale, grid_size[0] * grid_size[1] + 1, dim]
|
||||||
|
Here x_scale and y_scale are the number of tiles along x-axis and y-axis respectively.
|
||||||
|
"""
|
||||||
|
# first remove cls token
|
||||||
|
pos_embed = pos_and_cls_embed[:, :, 1:]
|
||||||
|
cls_embed = pos_and_cls_embed[:, :, 0].unsqueeze(2)
|
||||||
|
|
||||||
|
xs_old, ys_old, ntok, dim = pos_embed.shape
|
||||||
|
old_grid_size = int(math.sqrt(ntok))
|
||||||
|
|
||||||
|
# move to correct form for interpolation
|
||||||
|
pos_embed = pos_embed.view(xs_old, ys_old, old_grid_size, old_grid_size, dim)
|
||||||
|
pos_embed = pos_embed.permute(0, 2, 1, 3, 4).contiguous()
|
||||||
|
pos_embed = pos_embed.view(xs_old * old_grid_size, ys_old * old_grid_size, dim)
|
||||||
|
pos_embed = pos_embed.unsqueeze(0)
|
||||||
|
|
||||||
|
# interpolate
|
||||||
|
new_size = (grid_size[0] * x_scale, grid_size[1] * y_scale)
|
||||||
|
pos_embed = pos_embed.permute(0, 3, 1, 2)
|
||||||
|
pos_embed_resized = F.interpolate(
|
||||||
|
pos_embed,
|
||||||
|
size=new_size,
|
||||||
|
mode="bilinear",
|
||||||
|
align_corners=True,
|
||||||
|
)
|
||||||
|
pos_embed = pos_embed_resized.permute(0, 2, 3, 1)[0]
|
||||||
|
|
||||||
|
# move it back in place
|
||||||
|
pos_embed = pos_embed.view(x_scale, grid_size[0], y_scale, grid_size[1], dim)
|
||||||
|
pos_embed = pos_embed.permute(0, 2, 1, 3, 4).contiguous()
|
||||||
|
pos_embed = pos_embed.view(x_scale, y_scale, grid_size[0] * grid_size[1], dim)
|
||||||
|
|
||||||
|
# interpolate cls token
|
||||||
|
cls_embed = cls_embed.permute(2, 3, 0, 1)
|
||||||
|
cls_embed_resized = F.interpolate(
|
||||||
|
cls_embed,
|
||||||
|
size=(x_scale, y_scale),
|
||||||
|
mode="bilinear",
|
||||||
|
align_corners=True,
|
||||||
|
)
|
||||||
|
cls_embed = cls_embed_resized.permute(2, 3, 0, 1)
|
||||||
|
# add cls token back in
|
||||||
|
pos_and_cls_embed = torch.cat([cls_embed, pos_embed], dim=2)
|
||||||
|
|
||||||
|
return pos_and_cls_embed
|
||||||
|
|
||||||
|
|
||||||
|
def build_encoder_attention_mask(
|
||||||
|
x: torch.Tensor,
|
||||||
|
ar: torch.Tensor,
|
||||||
|
ntok: int,
|
||||||
|
num_chunks: int,
|
||||||
|
n_heads: int,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Build vision encoder attention mask that omits padding tokens.
|
||||||
|
"""
|
||||||
|
masks = []
|
||||||
|
for arx in ar:
|
||||||
|
mask_i = torch.ones((num_chunks, x.shape[2], 1), dtype=x.dtype)
|
||||||
|
mask_i[: arx[0] * arx[1], :ntok] = 0
|
||||||
|
mask_i = mask_i.view(num_chunks * x.shape[2], -1)
|
||||||
|
mask_i = mask_i @ mask_i.T * get_negative_inf_value(x.dtype)
|
||||||
|
mask_i = mask_i.unsqueeze(0)
|
||||||
|
masks.append(mask_i)
|
||||||
|
masks = torch.stack(masks).to(x.device).expand(-1, n_heads, -1, -1)
|
||||||
|
return masks
|
||||||
|
|
||||||
|
|
||||||
|
def expand_num_tokens_to_mult8(x):
|
||||||
|
num_pad_tokens = 8 - (x.shape[-2] % 8)
|
||||||
|
if num_pad_tokens == 0:
|
||||||
|
return x, 0
|
||||||
|
else:
|
||||||
|
return (
|
||||||
|
torch.cat(
|
||||||
|
[
|
||||||
|
x,
|
||||||
|
torch.zeros(
|
||||||
|
(x.shape[0], x.shape[1], num_pad_tokens, x.shape[-1]),
|
||||||
|
dtype=x.dtype,
|
||||||
|
device=x.device,
|
||||||
|
),
|
||||||
|
],
|
||||||
|
dim=-2,
|
||||||
|
),
|
||||||
|
num_pad_tokens,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def contract_num_tokens_from_mult8(x, num_pad_tokens):
|
||||||
|
if num_pad_tokens == 0:
|
||||||
|
return x
|
||||||
|
return x[:, :, :-num_pad_tokens]
|
408
llama_stack/models/llama/llama3/multimodal/image_transform.py
Normal file
408
llama_stack/models/llama/llama3/multimodal/image_transform.py
Normal file
|
@ -0,0 +1,408 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# top-level folder for each specific model found within the models/ directory at
|
||||||
|
# the top-level of this source tree.
|
||||||
|
|
||||||
|
import math
|
||||||
|
from collections import defaultdict
|
||||||
|
from logging import getLogger
|
||||||
|
from typing import Any, Optional, Set, Tuple
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torchvision.transforms as tv
|
||||||
|
from PIL import Image
|
||||||
|
from torchvision.transforms import functional as F
|
||||||
|
|
||||||
|
IMAGE_RES = 224
|
||||||
|
|
||||||
|
logger = getLogger()
|
||||||
|
|
||||||
|
|
||||||
|
class VariableSizeImageTransform(object):
|
||||||
|
"""
|
||||||
|
This class accepts images of any size and dynamically resize, pads and chunks it
|
||||||
|
based on the image aspect ratio and the number of image chunks we allow.
|
||||||
|
|
||||||
|
The algorithm will NOT distort the image fit a certain aspect ratio, because
|
||||||
|
that leads to a significant degradation in image quality.
|
||||||
|
|
||||||
|
It can be summarized in 6 steps:
|
||||||
|
1. Find all possible canvas combinations of max_num_chunks;
|
||||||
|
2. Find the best canvas to fit the image;
|
||||||
|
3. Resize without distortion
|
||||||
|
4. Pad
|
||||||
|
5. Normalize
|
||||||
|
6. Chunk
|
||||||
|
|
||||||
|
For example, if an input image is of size 300x800, patch_size of 224,
|
||||||
|
and max_num_chunks = 8, it will find the closest aspect ratio that
|
||||||
|
is allowed within 8 image chunks, with some restrictions.
|
||||||
|
In this case, 2:4 = 2 horizontal patches and 4 vertical patches,
|
||||||
|
giving a total of 8 chunks.
|
||||||
|
|
||||||
|
If resize_to_max_canvas, the image will be resized (without distortion),
|
||||||
|
to the largest possible resolution. In this case, 388:896, and padded to 448:896,
|
||||||
|
where we maintain the original aspect ratio and pad with zeros value for the rest.
|
||||||
|
This approach minimizes the amount of padding required for any arbitrary resolution.
|
||||||
|
|
||||||
|
However, if limit_upscaling_to_patch_size is set to True,
|
||||||
|
the upscaling will be limited to the patch size. In the example above,
|
||||||
|
the image would remain 300x800 (no upscaling), and then padded to 448:896.
|
||||||
|
|
||||||
|
The final output will therefore be of shape (8, 3, 224, 224), where 2x4
|
||||||
|
patches are coming from the resizing and chunking.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, size: int = IMAGE_RES) -> None:
|
||||||
|
self.size = size
|
||||||
|
logger.info(f"VariableSizeImageTransform size: {self.size}")
|
||||||
|
self.to_tensor = tv.ToTensor()
|
||||||
|
self._mean = (0.48145466, 0.4578275, 0.40821073)
|
||||||
|
self._std = (0.26862954, 0.26130258, 0.27577711)
|
||||||
|
self.normalize = tv.Normalize(
|
||||||
|
mean=self._mean,
|
||||||
|
std=self._std,
|
||||||
|
inplace=True,
|
||||||
|
)
|
||||||
|
self.resample = tv.InterpolationMode.BILINEAR
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_factors(n: int) -> Set[int]:
|
||||||
|
"""
|
||||||
|
Calculate all factors of a given number, i.e. a dividor that leaves
|
||||||
|
no remainder. For example, if n=12, it will return {1, 2, 3, 4, 6, 12}.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
n (int): The number to find factors for.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
set: A set containing all factors of the number.
|
||||||
|
"""
|
||||||
|
factors_set = set()
|
||||||
|
|
||||||
|
for i in range(1, int(n**0.5) + 1):
|
||||||
|
if n % i == 0:
|
||||||
|
factors_set.add(i)
|
||||||
|
factors_set.add(n // i)
|
||||||
|
return factors_set
|
||||||
|
|
||||||
|
def find_supported_resolutions(self, max_num_chunks: int, patch_size: int) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Computes all of the allowed resoltuions for a fixed number of chunks
|
||||||
|
and patch_size. Useful for when dividing an image into chunks.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
max_num_chunks (int): Maximum number of chunks for processing.
|
||||||
|
patch_size (int): Size of the side of the patch.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
torch.Tensor: List of possible resolutions as tuples (height, width).
|
||||||
|
|
||||||
|
Example:
|
||||||
|
>>> max_num_chunks = 5
|
||||||
|
>>> patch_size = 224
|
||||||
|
>>> find_supported_resolutions(max_num_chunks, patch_size)
|
||||||
|
tensor([(224, 896), (448, 448), (224, 224), (896, 224), (224, 672),
|
||||||
|
(672, 224), (224, 448), (448, 224)])
|
||||||
|
|
||||||
|
Given max_num_chunks=4, patch_size=224, it will create a dictionary:
|
||||||
|
{
|
||||||
|
0.25: [(1, 4)],
|
||||||
|
1.0: [(2, 2), (1, 1)],
|
||||||
|
4.0: [(4, 1)],
|
||||||
|
0.33: [(1, 3)],
|
||||||
|
3.0: [(3, 1)],
|
||||||
|
0.5: [(1, 2)],
|
||||||
|
2.0: [(2, 1)]
|
||||||
|
}
|
||||||
|
|
||||||
|
and return the resolutions multiplied by the patch_size:
|
||||||
|
[(1*224, 4*224), (2*224, 2*224), ..., (2*224, 1*224)]
|
||||||
|
"""
|
||||||
|
asp_dict = defaultdict(list)
|
||||||
|
for chunk_size in range(max_num_chunks, 0, -1):
|
||||||
|
_factors = sorted(self.get_factors(chunk_size))
|
||||||
|
_asp_ratios = [(factor, chunk_size // factor) for factor in _factors]
|
||||||
|
for height, width in _asp_ratios:
|
||||||
|
ratio_float = height / width
|
||||||
|
asp_dict[ratio_float].append((height, width))
|
||||||
|
|
||||||
|
# get the resolutions multiplied by the patch_size
|
||||||
|
possible_resolutions = []
|
||||||
|
for value in asp_dict.values():
|
||||||
|
for height, depth in value:
|
||||||
|
possible_resolutions.append((height * patch_size, depth * patch_size))
|
||||||
|
|
||||||
|
return possible_resolutions
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_max_res_without_distortion(
|
||||||
|
image_size: Tuple[int, int],
|
||||||
|
target_size: Tuple[int, int],
|
||||||
|
) -> Tuple[int, int]:
|
||||||
|
"""
|
||||||
|
Determines the maximum resolution to which an image can be resized to without distorting its
|
||||||
|
aspect ratio, based on the target resolution.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
image_size (Tuple[int, int]): The original resolution of the image (height, width).
|
||||||
|
target_resolution (Tuple[int, int]): The desired resolution to fit the image into (height, width).
|
||||||
|
Returns:
|
||||||
|
Tuple[int, int]: The optimal dimensions (height, width) to which the image should be resized.
|
||||||
|
Example:
|
||||||
|
>>> _get_max_res_without_distortion([200, 300], target_size = [450, 200])
|
||||||
|
(134, 200)
|
||||||
|
>>> _get_max_res_without_distortion([800, 600], target_size = [450, 1300])
|
||||||
|
(450, 338)
|
||||||
|
"""
|
||||||
|
|
||||||
|
original_width, original_height = image_size
|
||||||
|
target_width, target_height = target_size
|
||||||
|
|
||||||
|
scale_w = target_width / original_width
|
||||||
|
scale_h = target_height / original_height
|
||||||
|
|
||||||
|
if scale_w < scale_h:
|
||||||
|
new_width = target_width
|
||||||
|
new_height = min(math.floor(original_height * scale_w), target_height)
|
||||||
|
else:
|
||||||
|
new_height = target_height
|
||||||
|
new_width = min(math.floor(original_width * scale_h), target_width)
|
||||||
|
|
||||||
|
return new_width, new_height
|
||||||
|
|
||||||
|
def _pad(self, image: Image.Image, target_size) -> Image.Image:
|
||||||
|
new_width, new_height = target_size
|
||||||
|
new_im = Image.new(mode="RGB", size=(new_width, new_height), color=(0, 0, 0)) # type: ignore
|
||||||
|
new_im.paste(image)
|
||||||
|
return new_im
|
||||||
|
|
||||||
|
def _split(self, image: torch.Tensor, ncw: int, nch: int) -> torch.Tensor:
|
||||||
|
# Split image into number of required tiles (width x height)
|
||||||
|
num_channels, height, width = image.size()
|
||||||
|
image = image.view(num_channels, nch, height // nch, ncw, width // ncw)
|
||||||
|
# Permute dimensions to reorder the axes
|
||||||
|
image = image.permute(1, 3, 0, 2, 4).contiguous()
|
||||||
|
# Reshape into the desired output shape (batch_size * 4, num_channels, width/2, height/2)
|
||||||
|
image = image.view(ncw * nch, num_channels, height // nch, width // ncw)
|
||||||
|
return image
|
||||||
|
|
||||||
|
def resize_without_distortion(
|
||||||
|
self,
|
||||||
|
image: torch.Tensor,
|
||||||
|
target_size: Tuple[int, int],
|
||||||
|
max_upscaling_size: Optional[int],
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Used to resize an image to target_resolution, without distortion.
|
||||||
|
|
||||||
|
If target_size requires upscaling the image, the user can set max_upscaling_size to
|
||||||
|
limit the upscaling to a maximum size. In this case, since we rescale without distortion,
|
||||||
|
modifying target_size works as a boundary for the image's largest side.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
resample (str): Resampling method used when resizing images.
|
||||||
|
Supports "nearest", "nearest_exact", "bilinear", "bicubic".
|
||||||
|
max_upscaling_size (int): The maximum size to upscale the image to.
|
||||||
|
If None, there is no limit.
|
||||||
|
Examples:
|
||||||
|
>>> target_size = (1000, 1200)
|
||||||
|
>>> max_upscaling_size = 600
|
||||||
|
>>> image_size = (400, 200)
|
||||||
|
>>> resize_without_distortion(image_size, target_size, max_upscaling_size)
|
||||||
|
(600, 300) # new_size_without_distortion
|
||||||
|
|
||||||
|
>>> target_size = (1000, 1200)
|
||||||
|
>>> max_upscaling_size = 600
|
||||||
|
>>> image_size = (2000, 200)
|
||||||
|
>>> resize_without_distortion(image_size, target_size, max_upscaling_size)
|
||||||
|
(1000, 100) # new_size_without_distortion
|
||||||
|
|
||||||
|
>>> target_size = (1000, 1200)
|
||||||
|
>>> max_upscaling_size = 2000
|
||||||
|
>>> image_size = (400, 200)
|
||||||
|
>>> resize_without_distortion(image_size, target_size, max_upscaling_size)
|
||||||
|
(1000, 500) # new_size_without_distortion
|
||||||
|
|
||||||
|
>>> target_size = (1000, 1200)
|
||||||
|
>>> max_upscaling_size = None
|
||||||
|
>>> image_size = (400, 200)
|
||||||
|
>>> resize_without_distortion(image_size, target_size, max_upscaling_size)
|
||||||
|
(1000, 500) # new_size_without_distortion
|
||||||
|
"""
|
||||||
|
|
||||||
|
image_width, image_height = image.size
|
||||||
|
image_size = (image_width, image_height)
|
||||||
|
|
||||||
|
# If target_size requires upscaling, we might want to limit the upscaling to max_upscaling_size
|
||||||
|
if max_upscaling_size is not None:
|
||||||
|
new_target_width = min(max(image_width, max_upscaling_size), target_size[0])
|
||||||
|
new_target_height = min(max(image_height, max_upscaling_size), target_size[1])
|
||||||
|
target_size = (new_target_width, new_target_height)
|
||||||
|
|
||||||
|
# resize to target_size while preserving aspect ratio
|
||||||
|
new_size_without_distortion = self.get_max_res_without_distortion(image_size, target_size)
|
||||||
|
|
||||||
|
image = F.resize(
|
||||||
|
image,
|
||||||
|
(new_size_without_distortion[1], new_size_without_distortion[0]),
|
||||||
|
interpolation=self.resample,
|
||||||
|
)
|
||||||
|
|
||||||
|
return image
|
||||||
|
|
||||||
|
def get_best_fit(
|
||||||
|
self,
|
||||||
|
image_size: Tuple[int, int],
|
||||||
|
possible_resolutions: torch.Tensor,
|
||||||
|
resize_to_max_canvas: bool = False,
|
||||||
|
) -> Tuple[int, int]:
|
||||||
|
"""
|
||||||
|
Determines the best canvas possible from a list of possible resolutions to, without distortion,
|
||||||
|
resize an image to.
|
||||||
|
|
||||||
|
For each possible resolution, calculates the scaling factors for
|
||||||
|
width and height, and selects the smallest one, which is the limiting side.
|
||||||
|
E.g. to match the canvas you can upscale height by 2x, and width by 1.5x,
|
||||||
|
therefore, the maximum upscaling you can do is min(2, 1.5) = 1.5.
|
||||||
|
|
||||||
|
If upscaling is possible (any of the scaling factors is greater than 1),
|
||||||
|
then picks the smallest upscaling factor > 1, unless resize_to_max_canvas is True.
|
||||||
|
|
||||||
|
If upscaling is not possible, then picks the largest scaling factor <= 1, i.e.
|
||||||
|
reduce downscaling as much as possible.
|
||||||
|
|
||||||
|
If there are multiple resolutions with the same max scale, we pick the one with the lowest area,
|
||||||
|
to minimize padding. E.g., the same image can be upscaled to 224x224 and 224x448, but the latter
|
||||||
|
has more padding.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
image_size (Tuple[int, int]): A tuple containing the height and width of the image.
|
||||||
|
possible_resolutions (torch.Tensor): A tensor of shape (N, 2) where each
|
||||||
|
row represents a possible resolution (height, width).
|
||||||
|
use_max_upscaling (bool): If True, will return the largest upscaling resolution.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List[int]: The best resolution [height, width] for the given image.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
>>> image_size = (200, 300)
|
||||||
|
>>> possible_resolutions = torch.tensor([[224, 672],
|
||||||
|
... [672, 224],
|
||||||
|
... [224, 448],
|
||||||
|
... [448, 224],
|
||||||
|
... [224, 224]])
|
||||||
|
>>> _get_smallest_upscaling_possibility(image_size, possible_resolutions)
|
||||||
|
[224, 448]
|
||||||
|
|
||||||
|
We have:
|
||||||
|
scale_w = tensor([2.2400, 0.7467, 1.4933, 0.7467, 0.7467])
|
||||||
|
scale_h = tensor([1.1200, 3.3600, 1.1200, 2.2400, 1.1200])
|
||||||
|
scales = tensor([1.1200, 0.7467, 1.1200, 0.7467, 0.7467])
|
||||||
|
Only one of the scales > 1:
|
||||||
|
upscaling_possible = tensor([1.1200, 1.1200])
|
||||||
|
smallest_rescale = tensor(1.1200)
|
||||||
|
So we pick the resolution with the smallest smallest area:
|
||||||
|
areas = tensor([150528, 100352]) # [672, 224], [224, 448]
|
||||||
|
optimal_canvas = tensor([224, 448])
|
||||||
|
"""
|
||||||
|
|
||||||
|
original_width, original_height = image_size
|
||||||
|
|
||||||
|
# get all possible resolutions heights/widths
|
||||||
|
target_widths, target_heights = (
|
||||||
|
possible_resolutions[:, 0],
|
||||||
|
possible_resolutions[:, 1],
|
||||||
|
)
|
||||||
|
|
||||||
|
# get scaling factors to resize the image without distortion
|
||||||
|
scale_w = target_widths / original_width
|
||||||
|
scale_h = target_heights / original_height
|
||||||
|
|
||||||
|
# get the min scale between width and height (limiting side -> no distortion)
|
||||||
|
scales = torch.where(scale_w > scale_h, scale_h, scale_w)
|
||||||
|
|
||||||
|
# filter only scales that allow upscaling
|
||||||
|
upscaling_options = scales[scales >= 1]
|
||||||
|
if len(upscaling_options) > 0:
|
||||||
|
if resize_to_max_canvas:
|
||||||
|
selected_scale = torch.max(upscaling_options)
|
||||||
|
else:
|
||||||
|
selected_scale = torch.min(upscaling_options)
|
||||||
|
else:
|
||||||
|
# no upscaling possible,
|
||||||
|
# get the minimum downscaling (max scale for scales<1)
|
||||||
|
downscaling_options = scales[scales < 1]
|
||||||
|
selected_scale = torch.max(downscaling_options)
|
||||||
|
|
||||||
|
# get all resolutions that support this scaling factor,
|
||||||
|
# e.g. you can upscale to 224x224, 224x448, 224x672 without distortion
|
||||||
|
chosen_canvas = possible_resolutions[scales == selected_scale]
|
||||||
|
|
||||||
|
# if there are multiple resolutions,
|
||||||
|
# get the one with minimum area to reduce padding
|
||||||
|
if len(chosen_canvas) > 1:
|
||||||
|
areas = chosen_canvas[:, 0] * chosen_canvas[:, 1]
|
||||||
|
optimal_idx = torch.argmin(areas)
|
||||||
|
optimal_canvas = chosen_canvas[optimal_idx]
|
||||||
|
else:
|
||||||
|
optimal_canvas = chosen_canvas[0]
|
||||||
|
|
||||||
|
return tuple(optimal_canvas.tolist())
|
||||||
|
|
||||||
|
def __call__(
|
||||||
|
self,
|
||||||
|
image: Image.Image,
|
||||||
|
max_num_chunks: int,
|
||||||
|
normalize_img: bool = True,
|
||||||
|
resize_to_max_canvas: bool = False,
|
||||||
|
) -> Tuple[Any, Any]:
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
image (PIL.Image): Image to be resized.
|
||||||
|
max_num_chunks (int): Maximum number of chunks to split the image into.
|
||||||
|
normalize_img (bool): Whether to normalize the image.
|
||||||
|
resize_to_max_canvas (bool): Whether to resize the image to the maximum canvas size.
|
||||||
|
If True, picks the canvas the allows the largest resizing without distortion.
|
||||||
|
If False, downsample as little as possible, including no resizing at all,
|
||||||
|
but never upsample, unless the image is smaller than the patch size.
|
||||||
|
"""
|
||||||
|
assert max_num_chunks > 0
|
||||||
|
assert isinstance(image, Image.Image), type(image)
|
||||||
|
w, h = image.size
|
||||||
|
|
||||||
|
possible_resolutions = self.find_supported_resolutions(max_num_chunks=max_num_chunks, patch_size=self.size)
|
||||||
|
possible_resolutions = torch.tensor(possible_resolutions)
|
||||||
|
|
||||||
|
best_resolution = self.get_best_fit(
|
||||||
|
image_size=(w, h),
|
||||||
|
possible_resolutions=possible_resolutions,
|
||||||
|
resize_to_max_canvas=resize_to_max_canvas,
|
||||||
|
)
|
||||||
|
|
||||||
|
max_upscaling_size = None if resize_to_max_canvas else self.size
|
||||||
|
image = self.resize_without_distortion(image, best_resolution, max_upscaling_size)
|
||||||
|
image = self._pad(image, best_resolution)
|
||||||
|
|
||||||
|
image = self.to_tensor(image)
|
||||||
|
|
||||||
|
if normalize_img:
|
||||||
|
image = self.normalize(image)
|
||||||
|
|
||||||
|
ratio_w, ratio_h = (
|
||||||
|
best_resolution[0] // self.size,
|
||||||
|
best_resolution[1] // self.size,
|
||||||
|
)
|
||||||
|
|
||||||
|
image = self._split(image, ratio_w, ratio_h) # type: ignore
|
||||||
|
|
||||||
|
ar = (ratio_h, ratio_w)
|
||||||
|
return image, ar
|
1435
llama_stack/models/llama/llama3/multimodal/model.py
Normal file
1435
llama_stack/models/llama/llama3/multimodal/model.py
Normal file
File diff suppressed because it is too large
Load diff
26
llama_stack/models/llama/llama3/multimodal/utils.py
Normal file
26
llama_stack/models/llama/llama3/multimodal/utils.py
Normal file
|
@ -0,0 +1,26 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# top-level folder for each specific model found within the models/ directory at
|
||||||
|
# the top-level of this source tree.
|
||||||
|
|
||||||
|
import collections
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
def get_negative_inf_value(dtype):
|
||||||
|
return torch.finfo(dtype).min
|
||||||
|
|
||||||
|
|
||||||
|
def to_2tuple(x):
|
||||||
|
if isinstance(x, collections.abc.Iterable):
|
||||||
|
return x
|
||||||
|
return (x, x)
|
|
@ -15,11 +15,8 @@ import textwrap
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Any, List, Optional
|
from typing import Any, List, Optional
|
||||||
|
|
||||||
from llama_models.datatypes import (
|
|
||||||
BuiltinTool,
|
|
||||||
)
|
|
||||||
|
|
||||||
from llama_stack.models.llama.datatypes import (
|
from llama_stack.models.llama.datatypes import (
|
||||||
|
BuiltinTool,
|
||||||
ToolDefinition,
|
ToolDefinition,
|
||||||
ToolParamDefinition,
|
ToolParamDefinition,
|
||||||
)
|
)
|
||||||
|
@ -226,10 +223,9 @@ class FunctionTagCustomToolGenerator(PromptTemplateGeneratorBase):
|
||||||
class PythonListCustomToolGenerator(PromptTemplateGeneratorBase): # noqa: N801
|
class PythonListCustomToolGenerator(PromptTemplateGeneratorBase): # noqa: N801
|
||||||
DEFAULT_PROMPT = textwrap.dedent(
|
DEFAULT_PROMPT = textwrap.dedent(
|
||||||
"""
|
"""
|
||||||
|
You are a helpful assistant. You have access to functions, but you should only use them if they are required.
|
||||||
You are an expert in composing functions. You are given a question and a set of possible functions.
|
You are an expert in composing functions. You are given a question and a set of possible functions.
|
||||||
Based on the question, you will need to make one or more function/tool calls to achieve the purpose.
|
Based on the question, you may or may not need to make one function/tool call to achieve the purpose.
|
||||||
If none of the function can be used, point it out. If the given question lacks the parameters required by the function,
|
|
||||||
also point it out. You should only return the function call in tools call sections.
|
|
||||||
|
|
||||||
{{ function_description }}
|
{{ function_description }}
|
||||||
""".strip("\n")
|
""".strip("\n")
|
||||||
|
|
|
@ -11,7 +11,7 @@
|
||||||
# top-level folder for each specific model found within the models/ directory at
|
# top-level folder for each specific model found within the models/ directory at
|
||||||
# the top-level of this source tree.
|
# the top-level of this source tree.
|
||||||
|
|
||||||
from llama_models.datatypes import (
|
from llama_stack.models.llama.datatypes import (
|
||||||
BuiltinTool,
|
BuiltinTool,
|
||||||
StopReason,
|
StopReason,
|
||||||
ToolCall,
|
ToolCall,
|
||||||
|
|
128000
llama_stack/models/llama/llama3/tokenizer.model
Normal file
128000
llama_stack/models/llama/llama3/tokenizer.model
Normal file
File diff suppressed because it is too large
Load diff
214
llama_stack/models/llama/llama3/tokenizer.py
Normal file
214
llama_stack/models/llama/llama3/tokenizer.py
Normal file
|
@ -0,0 +1,214 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# top-level folder for each specific model found within the models/ directory at
|
||||||
|
# the top-level of this source tree.
|
||||||
|
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
|
||||||
|
|
||||||
|
import os
|
||||||
|
from logging import getLogger
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import (
|
||||||
|
AbstractSet,
|
||||||
|
Collection,
|
||||||
|
Dict,
|
||||||
|
Iterator,
|
||||||
|
List,
|
||||||
|
Literal,
|
||||||
|
Optional,
|
||||||
|
Sequence,
|
||||||
|
Union,
|
||||||
|
cast,
|
||||||
|
)
|
||||||
|
|
||||||
|
import tiktoken
|
||||||
|
from tiktoken.load import load_tiktoken_bpe
|
||||||
|
|
||||||
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
# The tiktoken tokenizer can handle <=400k chars without
|
||||||
|
# pyo3_runtime.PanicException.
|
||||||
|
TIKTOKEN_MAX_ENCODE_CHARS = 400_000
|
||||||
|
|
||||||
|
# https://github.com/openai/tiktoken/issues/195
|
||||||
|
# Here we iterate over subsequences and split if we exceed the limit
|
||||||
|
# of max consecutive non-whitespace or whitespace characters.
|
||||||
|
MAX_NO_WHITESPACES_CHARS = 25_000
|
||||||
|
|
||||||
|
|
||||||
|
_INSTANCE = None
|
||||||
|
|
||||||
|
|
||||||
|
class Tokenizer:
|
||||||
|
"""
|
||||||
|
Tokenizing and encoding/decoding text using the Tiktoken tokenizer.
|
||||||
|
"""
|
||||||
|
|
||||||
|
special_tokens: Dict[str, int]
|
||||||
|
|
||||||
|
num_reserved_special_tokens = 256
|
||||||
|
|
||||||
|
pat_str = r"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+" # noqa: E501
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_instance(cls):
|
||||||
|
global _INSTANCE
|
||||||
|
|
||||||
|
if _INSTANCE is None:
|
||||||
|
_INSTANCE = Tokenizer(os.path.join(os.path.dirname(__file__), "tokenizer.model"))
|
||||||
|
return _INSTANCE
|
||||||
|
|
||||||
|
def __init__(self, model_path: str):
|
||||||
|
"""
|
||||||
|
Initializes the Tokenizer with a Tiktoken model.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_path (str): The path to the Tiktoken model file.
|
||||||
|
"""
|
||||||
|
assert os.path.isfile(model_path), model_path
|
||||||
|
|
||||||
|
mergeable_ranks = load_tiktoken_bpe(model_path)
|
||||||
|
num_base_tokens = len(mergeable_ranks)
|
||||||
|
special_tokens = [
|
||||||
|
"<|begin_of_text|>",
|
||||||
|
"<|end_of_text|>",
|
||||||
|
"<|reserved_special_token_0|>",
|
||||||
|
"<|reserved_special_token_1|>",
|
||||||
|
"<|finetune_right_pad_id|>",
|
||||||
|
"<|step_id|>",
|
||||||
|
"<|start_header_id|>",
|
||||||
|
"<|end_header_id|>",
|
||||||
|
"<|eom_id|>", # end of message
|
||||||
|
"<|eot_id|>", # end of turn
|
||||||
|
"<|python_tag|>",
|
||||||
|
"<|image|>",
|
||||||
|
]
|
||||||
|
reserved_tokens = [
|
||||||
|
f"<|reserved_special_token_{2 + i}|>" for i in range(self.num_reserved_special_tokens - len(special_tokens))
|
||||||
|
]
|
||||||
|
special_tokens = special_tokens + reserved_tokens
|
||||||
|
|
||||||
|
self.special_tokens = {token: num_base_tokens + i for i, token in enumerate(special_tokens)}
|
||||||
|
self.model = tiktoken.Encoding(
|
||||||
|
name=Path(model_path).name,
|
||||||
|
pat_str=self.pat_str,
|
||||||
|
mergeable_ranks=mergeable_ranks,
|
||||||
|
special_tokens=self.special_tokens,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.n_words: int = num_base_tokens + len(special_tokens)
|
||||||
|
# BOS / EOS token IDs
|
||||||
|
self.bos_id: int = self.special_tokens["<|begin_of_text|>"]
|
||||||
|
self.eos_id: int = self.special_tokens["<|end_of_text|>"]
|
||||||
|
self.eot_id: int = self.special_tokens["<|eot_id|>"]
|
||||||
|
self.eom_id: int = self.special_tokens["<|eom_id|>"]
|
||||||
|
self.python_tag_id = self.special_tokens["<|python_tag|>"]
|
||||||
|
self.pad_id: int = self.special_tokens["<|finetune_right_pad_id|>"]
|
||||||
|
self.stop_tokens = [
|
||||||
|
self.eos_id,
|
||||||
|
self.special_tokens["<|eom_id|>"],
|
||||||
|
self.special_tokens["<|eot_id|>"],
|
||||||
|
]
|
||||||
|
|
||||||
|
def encode(
|
||||||
|
self,
|
||||||
|
s: str,
|
||||||
|
*,
|
||||||
|
bos: bool,
|
||||||
|
eos: bool,
|
||||||
|
allowed_special: Optional[Union[Literal["all"], AbstractSet[str]]] = None,
|
||||||
|
disallowed_special: Union[Literal["all"], Collection[str]] = (),
|
||||||
|
) -> List[int]:
|
||||||
|
"""
|
||||||
|
Encodes a string into a list of token IDs.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
s (str): The input string to be encoded.
|
||||||
|
bos (bool): Whether to prepend the beginning-of-sequence token.
|
||||||
|
eos (bool): Whether to append the end-of-sequence token.
|
||||||
|
allowed_special ("all"|set[str]): allowed special tokens in string
|
||||||
|
disallowed_special ("all"|set[str]): special tokens that raise an error when in string
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
list[int]: A list of token IDs.
|
||||||
|
|
||||||
|
By default, setting disallowed_special=() encodes a string by ignoring
|
||||||
|
special tokens. Specifically:
|
||||||
|
- Setting `disallowed_special` to () will cause all text corresponding
|
||||||
|
to special tokens to be encoded as natural text (insteading of raising
|
||||||
|
an error).
|
||||||
|
- Setting `allowed_special` to "all" will treat all text corresponding
|
||||||
|
to special tokens to be encoded as special tokens.
|
||||||
|
"""
|
||||||
|
if allowed_special is None:
|
||||||
|
allowed_special = set()
|
||||||
|
assert type(s) is str
|
||||||
|
|
||||||
|
substrs = (
|
||||||
|
substr
|
||||||
|
for i in range(0, len(s), TIKTOKEN_MAX_ENCODE_CHARS)
|
||||||
|
for substr in self._split_whitespaces_or_nonwhitespaces(
|
||||||
|
s[i : i + TIKTOKEN_MAX_ENCODE_CHARS], MAX_NO_WHITESPACES_CHARS
|
||||||
|
)
|
||||||
|
)
|
||||||
|
t: List[int] = []
|
||||||
|
for substr in substrs:
|
||||||
|
t.extend(
|
||||||
|
self.model.encode(
|
||||||
|
substr,
|
||||||
|
allowed_special=allowed_special,
|
||||||
|
disallowed_special=disallowed_special,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
if bos:
|
||||||
|
t.insert(0, self.bos_id)
|
||||||
|
if eos:
|
||||||
|
t.append(self.eos_id)
|
||||||
|
return t
|
||||||
|
|
||||||
|
def decode(self, t: Sequence[int]) -> str:
|
||||||
|
"""
|
||||||
|
Decodes a list of token IDs into a string.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
t (List[int]): The list of token IDs to be decoded.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: The decoded string.
|
||||||
|
"""
|
||||||
|
# Typecast is safe here. Tiktoken doesn't do anything list-related with the sequence.
|
||||||
|
return self.model.decode(cast(List[int], t))
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _split_whitespaces_or_nonwhitespaces(s: str, max_consecutive_slice_len: int) -> Iterator[str]:
|
||||||
|
"""
|
||||||
|
Splits the string `s` so that each substring contains no more than `max_consecutive_slice_len`
|
||||||
|
consecutive whitespaces or consecutive non-whitespaces.
|
||||||
|
"""
|
||||||
|
current_slice_len = 0
|
||||||
|
current_slice_is_space = s[0].isspace() if len(s) > 0 else False
|
||||||
|
slice_start = 0
|
||||||
|
|
||||||
|
for i in range(len(s)):
|
||||||
|
is_now_space = s[i].isspace()
|
||||||
|
|
||||||
|
if current_slice_is_space ^ is_now_space:
|
||||||
|
current_slice_len = 1
|
||||||
|
current_slice_is_space = is_now_space
|
||||||
|
else:
|
||||||
|
current_slice_len += 1
|
||||||
|
if current_slice_len > max_consecutive_slice_len:
|
||||||
|
yield s[slice_start:i]
|
||||||
|
slice_start = i
|
||||||
|
current_slice_len = 1
|
||||||
|
yield s[slice_start:]
|
199
llama_stack/models/llama/llama3/tool_utils.py
Normal file
199
llama_stack/models/llama/llama3/tool_utils.py
Normal file
|
@ -0,0 +1,199 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# top-level folder for each specific model found within the models/ directory at
|
||||||
|
# the top-level of this source tree.
|
||||||
|
import ast
|
||||||
|
import json
|
||||||
|
import re
|
||||||
|
from typing import Optional, Tuple
|
||||||
|
|
||||||
|
from llama_stack.models.llama.datatypes import BuiltinTool, RecursiveType, ToolCall, ToolPromptFormat
|
||||||
|
|
||||||
|
BUILTIN_TOOL_PATTERN = r'\b(?P<tool_name>\w+)\.call\(query="(?P<query>[^"]*)"\)'
|
||||||
|
CUSTOM_TOOL_CALL_PATTERN = re.compile(r"<function=(?P<function_name>[^}]+)>(?P<args>{.*?})")
|
||||||
|
|
||||||
|
|
||||||
|
def is_json(s):
|
||||||
|
try:
|
||||||
|
parsed = json.loads(s)
|
||||||
|
# Return True for valid objects and not for ints, strings, etc
|
||||||
|
return isinstance(parsed, dict)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
def is_valid_python_list(input_string):
|
||||||
|
"""Check if the input string is a valid Python list of function calls"""
|
||||||
|
try:
|
||||||
|
# Try to parse the string
|
||||||
|
tree = ast.parse(input_string)
|
||||||
|
|
||||||
|
# Check if it's a single expression
|
||||||
|
if len(tree.body) != 1 or not isinstance(tree.body[0], ast.Expr):
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Check if the expression is a list
|
||||||
|
expr = tree.body[0].value
|
||||||
|
if not isinstance(expr, ast.List):
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Check if the list is empty
|
||||||
|
if len(expr.elts) == 0:
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Check if all elements in the list are function calls
|
||||||
|
for element in expr.elts:
|
||||||
|
if not isinstance(element, ast.Call):
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Check if the function call has a valid name
|
||||||
|
if not isinstance(element.func, ast.Name):
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Check if all arguments are keyword arguments
|
||||||
|
if element.args or not all(isinstance(arg, ast.keyword) for arg in element.keywords):
|
||||||
|
return False
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
except SyntaxError:
|
||||||
|
# If parsing fails, it's not a valid Python expression
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def parse_python_list_for_function_calls(input_string):
|
||||||
|
"""
|
||||||
|
Parse a Python list of function calls and
|
||||||
|
return a list of tuples containing the function name and arguments
|
||||||
|
"""
|
||||||
|
# Parse the string into an AST
|
||||||
|
tree = ast.parse(input_string)
|
||||||
|
|
||||||
|
# Ensure the input is a list
|
||||||
|
if not isinstance(tree.body[0], ast.Expr) or not isinstance(tree.body[0].value, ast.List):
|
||||||
|
raise ValueError("Input must be a list of function calls")
|
||||||
|
|
||||||
|
result = []
|
||||||
|
|
||||||
|
# Iterate through each function call in the list
|
||||||
|
for node in tree.body[0].value.elts:
|
||||||
|
if isinstance(node, ast.Call):
|
||||||
|
function_name = node.func.id
|
||||||
|
function_args = {}
|
||||||
|
|
||||||
|
# Extract keyword arguments
|
||||||
|
for keyword in node.keywords:
|
||||||
|
function_args[keyword.arg] = ast.literal_eval(keyword.value)
|
||||||
|
|
||||||
|
result.append((function_name, function_args))
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
class ToolUtils:
|
||||||
|
@staticmethod
|
||||||
|
def is_builtin_tool_call(message_body: str) -> bool:
|
||||||
|
match = re.search(ToolUtils.BUILTIN_TOOL_PATTERN, message_body)
|
||||||
|
return match is not None
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def maybe_extract_builtin_tool_call(message_body: str) -> Optional[Tuple[str, str]]:
|
||||||
|
# Find the first match in the text
|
||||||
|
match = re.search(BUILTIN_TOOL_PATTERN, message_body)
|
||||||
|
|
||||||
|
# Check if a match is found and return it
|
||||||
|
if match:
|
||||||
|
tool_name = match.group("tool_name")
|
||||||
|
query = match.group("query")
|
||||||
|
return tool_name, query
|
||||||
|
else:
|
||||||
|
return None
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def maybe_extract_custom_tool_call(message_body: str) -> Optional[Tuple[str, str]]:
|
||||||
|
# NOTE: Custom function too calls are still experimental
|
||||||
|
# Sometimes, response is of the form
|
||||||
|
# {"type": "function", "name": "function_name", "parameters": {...}
|
||||||
|
# and some times
|
||||||
|
# <function=function_name>(parameters)</function>
|
||||||
|
|
||||||
|
# Find the first match in the text
|
||||||
|
match = re.search(CUSTOM_TOOL_CALL_PATTERN, message_body)
|
||||||
|
if match:
|
||||||
|
tool_name = match.group("function_name")
|
||||||
|
query = match.group("args")
|
||||||
|
try:
|
||||||
|
return tool_name, json.loads(query.replace("'", '"'))
|
||||||
|
except Exception as e:
|
||||||
|
print("Exception while parsing json query for custom tool call", query, e)
|
||||||
|
return None
|
||||||
|
elif is_json(message_body):
|
||||||
|
response = json.loads(message_body)
|
||||||
|
if ("type" in response and response["type"] == "function") or ("name" in response):
|
||||||
|
function_name = response["name"]
|
||||||
|
args = response["parameters"]
|
||||||
|
return function_name, args
|
||||||
|
else:
|
||||||
|
return None
|
||||||
|
elif is_valid_python_list(message_body):
|
||||||
|
res = parse_python_list_for_function_calls(message_body)
|
||||||
|
# FIXME: Enable multiple tool calls
|
||||||
|
return res[0]
|
||||||
|
else:
|
||||||
|
return None
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def encode_tool_call(t: ToolCall, tool_prompt_format: ToolPromptFormat) -> str:
|
||||||
|
if t.tool_name == BuiltinTool.brave_search:
|
||||||
|
q = t.arguments["query"]
|
||||||
|
return f'brave_search.call(query="{q}")'
|
||||||
|
elif t.tool_name == BuiltinTool.wolfram_alpha:
|
||||||
|
q = t.arguments["query"]
|
||||||
|
return f'wolfram_alpha.call(query="{q}")'
|
||||||
|
elif t.tool_name == BuiltinTool.photogen:
|
||||||
|
q = t.arguments["query"]
|
||||||
|
return f'photogen.call(query="{q}")'
|
||||||
|
elif t.tool_name == BuiltinTool.code_interpreter:
|
||||||
|
return t.arguments["code"]
|
||||||
|
else:
|
||||||
|
fname = t.tool_name
|
||||||
|
|
||||||
|
if tool_prompt_format == ToolPromptFormat.json:
|
||||||
|
return json.dumps(
|
||||||
|
{
|
||||||
|
"type": "function",
|
||||||
|
"name": fname,
|
||||||
|
"parameters": t.arguments,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
elif tool_prompt_format == ToolPromptFormat.function_tag:
|
||||||
|
args = json.dumps(t.arguments)
|
||||||
|
return f"<function={fname}>{args}</function>"
|
||||||
|
|
||||||
|
elif tool_prompt_format == ToolPromptFormat.python_list:
|
||||||
|
|
||||||
|
def format_value(value: RecursiveType) -> str:
|
||||||
|
if isinstance(value, str):
|
||||||
|
return f'"{value}"'
|
||||||
|
elif isinstance(value, (int, float, bool)) or value is None:
|
||||||
|
return str(value)
|
||||||
|
elif isinstance(value, list):
|
||||||
|
return f"[{', '.join(format_value(v) for v in value)}]"
|
||||||
|
elif isinstance(value, dict):
|
||||||
|
return f"{{{', '.join(f'{k}={format_value(v)}' for k, v in value.items())}}}"
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unsupported type: {type(value)}")
|
||||||
|
|
||||||
|
args_str = ", ".join(f"{k}={format_value(v)}" for k, v in t.arguments.items())
|
||||||
|
return f"[{fname}({args_str})]"
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unsupported tool prompt format: {tool_prompt_format}")
|
358
llama_stack/models/llama/llama3_1/prompt_format.md
Normal file
358
llama_stack/models/llama/llama3_1/prompt_format.md
Normal file
|
@ -0,0 +1,358 @@
|
||||||
|
|
||||||
|
|
||||||
|
# Llama 3.1 - Prompt Formats
|
||||||
|
## Tokens
|
||||||
|
Here is a list of special tokens that are supported by Llama 3.1:
|
||||||
|
- `<|begin_of_text|>`: Specifies the start of the prompt
|
||||||
|
- `<|end_of_text|>`: Model will cease to generate more tokens. This token is generated only by the base models.
|
||||||
|
- `<|finetune_right_pad_id|>`: This token is used for padding text sequences to the same length in a batch.
|
||||||
|
- `<|start_header_id|>` and `<|end_header_id|>`: These tokens enclose the role for a particular message. The possible roles are: [system, user, assistant and ipython]
|
||||||
|
- `<|eom_id|>`: End of message. A message represents a possible stopping point for execution where the model can inform the executor that a tool call needs to be made. This is used for multi-step interactions between the model and any available tools. This token is emitted by the model when the Environment: ipython instruction is used in the system prompt, or if the model calls for a built-in tool.
|
||||||
|
- `<|eot_id|>`: End of turn. Represents when the model has determined that it has finished interacting with the user message that initiated its response. This is used in two scenarios:
|
||||||
|
- at the end of a direct interaction between the model and the user
|
||||||
|
- at the end of multiple interactions between the model and any available tools
|
||||||
|
This token signals to the executor that the model has finished generating a response.
|
||||||
|
- `<|python_tag|>`: Is a special tag used in the model's response to signify a tool call.
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
There are 4 different roles that are supported by Llama 3.1
|
||||||
|
- `system`: Sets the context in which to interact with the AI model. It typically includes rules, guidelines, or necessary information that helps the model respond effectively.
|
||||||
|
- `user`: Represents the human interacting with the model. It includes the inputs, commands, and questions to the model.
|
||||||
|
- `ipython`: A new role introduced in Llama 3.1. Semantically, this role means "tool". This role is used to mark messages with the output of a tool call when sent back to the model from the executor.
|
||||||
|
- `assistant`: Represents the response generated by the AI model based on the context provided in the `system`, `ipython` and `user` prompts.
|
||||||
|
|
||||||
|
## Llama 3.1 Base Model
|
||||||
|
|
||||||
|
Text completion for Llama 3.1 base model uses this format.
|
||||||
|
|
||||||
|
##### Input Prompt Format
|
||||||
|
```
|
||||||
|
<|begin_of_text|>Color of sky is blue but sometimes can also be
|
||||||
|
```
|
||||||
|
|
||||||
|
##### Model Response Format
|
||||||
|
```
|
||||||
|
red, orange, yellow, green, purple, pink, brown, gray, black, white, and even rainbow colors. The color of the sky can change due to various reasons such as time of day, weather conditions, pollution, and atmospheric phenomena.
|
||||||
|
The color of the sky is primarily blue because of a phenomenon called
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
Note start special tag
|
||||||
|
|
||||||
|
|
||||||
|
## Llama 3.1 Instruct Model
|
||||||
|
## User and assistant conversation
|
||||||
|
|
||||||
|
Here is a regular multi-turn user assistant conversation and how its formatted.
|
||||||
|
|
||||||
|
##### Input Prompt Format
|
||||||
|
```
|
||||||
|
<|begin_of_text|><|start_header_id|>system<|end_header_id|>
|
||||||
|
|
||||||
|
You are a helpful assistant<|eot_id|><|start_header_id|>user<|end_header_id|>
|
||||||
|
|
||||||
|
Answer who are you in the form of jeopardy?<|eot_id|><|start_header_id|>assistant<|end_header_id|>
|
||||||
|
|
||||||
|
|
||||||
|
```
|
||||||
|
|
||||||
|
##### Model Response Format
|
||||||
|
```
|
||||||
|
Here's my response
|
||||||
|
|
||||||
|
"What is a helpful assistant?"<|eot_id|>
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
## Tool Calling Formats
|
||||||
|
|
||||||
|
|
||||||
|
The three built-in tools (brave_search, wolfram_alpha, and code interpreter) can be turned on using the system prompt:
|
||||||
|
- Brave Search: Tool call to perform web searches.
|
||||||
|
- Wolfram Alpha: Tool call to perform complex mathematical calculations.
|
||||||
|
- Code Interpreter: Enables the model to output python code.
|
||||||
|
|
||||||
|
## Builtin Tool Calling
|
||||||
|
|
||||||
|
|
||||||
|
Here is an example of a conversation using brave search
|
||||||
|
|
||||||
|
|
||||||
|
##### Input Prompt Format
|
||||||
|
```
|
||||||
|
<|begin_of_text|><|start_header_id|>system<|end_header_id|>
|
||||||
|
|
||||||
|
Environment: ipython
|
||||||
|
Tools: brave_search, wolfram_alpha
|
||||||
|
Cutting Knowledge Date: December 2023
|
||||||
|
Today Date: 21 September 2024
|
||||||
|
|
||||||
|
You are a helpful assistant.
|
||||||
|
<|eot_id|><|start_header_id|>user<|end_header_id|>
|
||||||
|
|
||||||
|
Search the web for the latest price of 1oz gold?<|eot_id|><|start_header_id|>assistant<|end_header_id|>
|
||||||
|
|
||||||
|
|
||||||
|
```
|
||||||
|
|
||||||
|
##### Model Response Format
|
||||||
|
```
|
||||||
|
<|python_tag|>brave_search.call(query="latest price of 1oz gold")<|eom_id|>
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
- Just including Environment: ipython turns on code interpreter; therefore, you don't need to specify code interpretation on the Tools: line. The model can generate python code which is interpreted by the executor, with the result provided back to the model.
|
||||||
|
- The message body of the assistant response starts with a special tag <|python_tag|>
|
||||||
|
- As alluded to above, in such an environment, the model can generate <|eom_id|> instead of just the standard <|eot_id|> . The latter indicates the turn is finished, while the former indicates continued multi-step reasoning. That is, the model is expecting a continuation message with the output of the tool call.
|
||||||
|
- The model tool call response is of the form `tool.call(query="...")` wher tool is `brave_search` or `wolfram_alpha`
|
||||||
|
|
||||||
|
|
||||||
|
## Builtin Code Interpreter
|
||||||
|
|
||||||
|
Here is an actual example of model responding with code
|
||||||
|
|
||||||
|
##### Input Prompt Format
|
||||||
|
```
|
||||||
|
<|begin_of_text|><|start_header_id|>system<|end_header_id|>
|
||||||
|
|
||||||
|
Environment: ipython<|eot_id|><|start_header_id|>user<|end_header_id|>
|
||||||
|
|
||||||
|
Write code to check if number is prime, use that to see if the number 7 is prime<|eot_id|><|start_header_id|>assistant<|end_header_id|>
|
||||||
|
|
||||||
|
|
||||||
|
```
|
||||||
|
|
||||||
|
##### Model Response Format
|
||||||
|
```
|
||||||
|
<|python_tag|>def is_prime(n):
|
||||||
|
if n <= 1
|
||||||
|
return False
|
||||||
|
for i in range(2, int(n**0.5) + 1):
|
||||||
|
if n % i == 0:
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
print(is_prime(7)) # Output: True<|eom_id|>
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
- Model starts with <|python_tag|> and continues writing python code that it needs to be executed
|
||||||
|
- No explicit mention of code_interpreter in system prompt. `Environment: ipython` implicitly enables it.
|
||||||
|
|
||||||
|
|
||||||
|
## Built-in tools full interaction
|
||||||
|
|
||||||
|
Here is a full interaction with the built-in tools including the tool response and the final assistant response.
|
||||||
|
|
||||||
|
##### Input Prompt Format
|
||||||
|
```
|
||||||
|
<|begin_of_text|><|start_header_id|>system<|end_header_id|>
|
||||||
|
|
||||||
|
Environment: ipython
|
||||||
|
Tools: brave_search, wolfram_alpha
|
||||||
|
<|eot_id|><|start_header_id|>user<|end_header_id|>
|
||||||
|
|
||||||
|
What is the 100th decimal of pi?<|eot_id|><|start_header_id|>assistant<|end_header_id|>
|
||||||
|
|
||||||
|
<|python_tag|>wolfram_alpha.call(query="100th decimal of pi")<|eom_id|><|start_header_id|>ipython<|end_header_id|>
|
||||||
|
|
||||||
|
|
||||||
|
{
|
||||||
|
"queryresult": {
|
||||||
|
"success": true,
|
||||||
|
"inputstring": "100th decimal of pi",
|
||||||
|
"pods": [
|
||||||
|
{
|
||||||
|
"title": "Input interpretation",
|
||||||
|
"subpods": [
|
||||||
|
{
|
||||||
|
"title": "",
|
||||||
|
"plaintext": "100th digit | π"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"title": "Nearby digits",
|
||||||
|
"subpods": [
|
||||||
|
{
|
||||||
|
"title": "",
|
||||||
|
"plaintext": "...86208998628034825342117067982148086513282306647093..."
|
||||||
|
}
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"title": "Result",
|
||||||
|
"primary": true,
|
||||||
|
"subpods": [
|
||||||
|
{
|
||||||
|
"title": "",
|
||||||
|
"plaintext": "7"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
<|eot_id|><|start_header_id|>assistant<|end_header_id|>
|
||||||
|
|
||||||
|
|
||||||
|
```
|
||||||
|
|
||||||
|
##### Model Response Format
|
||||||
|
```
|
||||||
|
The 100th decimal of pi is 7.<|eot_id|>
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
- Note the `<|python_tag|>` in the assistant response.
|
||||||
|
- Role is `ipython` for the wolfram alpha response that is passed back to the model.
|
||||||
|
- Final message from assistant has <|eot_id|> tag.
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
## Zero shot tool calling
|
||||||
|
## JSON based tool calling
|
||||||
|
|
||||||
|
|
||||||
|
Llama models can now output custom tool calls from a single message to allow easier tool calling.
|
||||||
|
The following prompts provide an example of how custom tools can be called from the output of the model.
|
||||||
|
It's important to note that the model itself does not execute the calls; it provides structured output to facilitate calling by an executor.
|
||||||
|
|
||||||
|
|
||||||
|
##### Input Prompt Format
|
||||||
|
```
|
||||||
|
<|begin_of_text|><|start_header_id|>system<|end_header_id|>
|
||||||
|
|
||||||
|
Environment: ipython
|
||||||
|
|
||||||
|
Cutting Knowledge Date: December 2023
|
||||||
|
Today Date: 21 September 2024
|
||||||
|
|
||||||
|
You are a helpful assistant.
|
||||||
|
<|eot_id|><|start_header_id|>user<|end_header_id|>
|
||||||
|
|
||||||
|
Answer the user's question by making use of the following functions if needed.
|
||||||
|
If none of the function can be used, please say so.
|
||||||
|
Here is a list of functions in JSON format:
|
||||||
|
{
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": "trending_songs",
|
||||||
|
"description": "Returns the trending songs on a Music site",
|
||||||
|
"parameters": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": [
|
||||||
|
{
|
||||||
|
"n": {
|
||||||
|
"type": "object",
|
||||||
|
"description": "The number of songs to return"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"genre": {
|
||||||
|
"type": "object",
|
||||||
|
"description": "The genre of the songs to return"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"required": ["n"]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Return function calls in JSON format.<|eot_id|><|start_header_id|>user<|end_header_id|>
|
||||||
|
|
||||||
|
Use tools to get latest trending songs<|eot_id|><|start_header_id|>assistant<|end_header_id|>
|
||||||
|
|
||||||
|
|
||||||
|
```
|
||||||
|
|
||||||
|
##### Model Response Format
|
||||||
|
```
|
||||||
|
<|python_tag|>{
|
||||||
|
"type": "function",
|
||||||
|
"name": "trending_songs",
|
||||||
|
"parameters": {
|
||||||
|
"n": "10",
|
||||||
|
"genre": "all"
|
||||||
|
}
|
||||||
|
}<|eom_id|>
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
- JSON format for providing tools needs name, description and parameters
|
||||||
|
- Model responds with `<|python_tag|>` and `<|eom_id|>` as `Environment: ipython` was in the system prompt
|
||||||
|
- Instructions for tools added as a user message
|
||||||
|
- Only single tool calls are supported as of now
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
## Example of a user defined tool calling
|
||||||
|
## `<function>` based tool calling
|
||||||
|
|
||||||
|
|
||||||
|
Here is an example of how you could also write custom instructions for model to do zero shot tool calling.
|
||||||
|
In this example, we define a custom tool calling format using the `<function>` tag.
|
||||||
|
|
||||||
|
|
||||||
|
##### Input Prompt Format
|
||||||
|
```
|
||||||
|
<|begin_of_text|><|start_header_id|>system<|end_header_id|>
|
||||||
|
|
||||||
|
Environment: ipython
|
||||||
|
|
||||||
|
Cutting Knowledge Date: December 2023
|
||||||
|
Today Date: 21 September 2024
|
||||||
|
|
||||||
|
You are a helpful assistant.
|
||||||
|
<|eot_id|><|start_header_id|>user<|end_header_id|>
|
||||||
|
|
||||||
|
You have access to the following functions:
|
||||||
|
|
||||||
|
Use the function 'trending_songs' to 'Returns the trending songs on a Music site':
|
||||||
|
{"name": "trending_songs", "description": "Returns the trending songs on a Music site", "parameters": {"genre": {"description": "The genre of the songs to return", "param_type": "str", "required": false}, "n": {"description": "The number of songs to return", "param_type": "int", "required": true}}}
|
||||||
|
|
||||||
|
Think very carefully before calling functions.
|
||||||
|
If you choose to call a function ONLY reply in the following format with no prefix or suffix:
|
||||||
|
|
||||||
|
<function=example_function_name>{"example_name": "example_value"}</function>
|
||||||
|
|
||||||
|
Reminder:
|
||||||
|
- If looking for real time information use relevant functions before falling back to brave_search
|
||||||
|
- Function calls MUST follow the specified format, start with <function= and end with </function>
|
||||||
|
- Required parameters MUST be specified
|
||||||
|
- Only call one function at a time
|
||||||
|
- Put the entire function call reply on one line<|eot_id|><|start_header_id|>user<|end_header_id|>
|
||||||
|
|
||||||
|
Use tools to get latest trending songs<|eot_id|><|start_header_id|>assistant<|end_header_id|>
|
||||||
|
|
||||||
|
|
||||||
|
```
|
||||||
|
|
||||||
|
##### Model Response Format
|
||||||
|
```
|
||||||
|
<function=trending_songs>{"n": 10}</function><|eot_id|>
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
- In this case, model does NOT respond with `<|python_tag|>` and ends with `<|eot_id|>`
|
||||||
|
- Instructions for tools added as a user message
|
||||||
|
|
||||||
|
|
||||||
|
Thank You!
|
|
@ -14,7 +14,7 @@
|
||||||
import textwrap
|
import textwrap
|
||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
from llama_models.datatypes import (
|
from llama_stack.models.llama.datatypes import (
|
||||||
BuiltinTool,
|
BuiltinTool,
|
||||||
RawMessage,
|
RawMessage,
|
||||||
StopReason,
|
StopReason,
|
||||||
|
|
|
@ -13,7 +13,7 @@
|
||||||
import json
|
import json
|
||||||
import textwrap
|
import textwrap
|
||||||
|
|
||||||
from llama_models.datatypes import (
|
from llama_stack.models.llama.datatypes import (
|
||||||
RawMessage,
|
RawMessage,
|
||||||
StopReason,
|
StopReason,
|
||||||
ToolCall,
|
ToolCall,
|
||||||
|
|
|
@ -14,7 +14,7 @@
|
||||||
import textwrap
|
import textwrap
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
from llama_models.datatypes import (
|
from llama_stack.models.llama.datatypes import (
|
||||||
RawMediaItem,
|
RawMediaItem,
|
||||||
RawMessage,
|
RawMessage,
|
||||||
RawTextItem,
|
RawTextItem,
|
||||||
|
|
286
llama_stack/models/llama/llama3_2/text_prompt_format.md
Normal file
286
llama_stack/models/llama/llama3_2/text_prompt_format.md
Normal file
|
@ -0,0 +1,286 @@
|
||||||
|
## User and assistant conversation
|
||||||
|
|
||||||
|
Here is a regular multi-turn user assistant conversation and how its formatted.
|
||||||
|
|
||||||
|
##### Input Prompt Format
|
||||||
|
```
|
||||||
|
<|begin_of_text|><|start_header_id|>system<|end_header_id|>
|
||||||
|
|
||||||
|
You are a helpful assistant<|eot_id|><|start_header_id|>user<|end_header_id|>
|
||||||
|
|
||||||
|
Who are you?<|eot_id|><|start_header_id|>assistant<|end_header_id|>
|
||||||
|
|
||||||
|
|
||||||
|
```
|
||||||
|
|
||||||
|
##### Model Response Format
|
||||||
|
```
|
||||||
|
I'm an AI assistant, which means I'm a computer program designed to simulate conversation and answer questions to the best of my ability. I'm here to help you with any questions or tasks you may have, and I'll do my best to provide you with accurate and helpful information.
|
||||||
|
|
||||||
|
I don't have a personal name, but I'm often referred to as a "virtual assistant" or a "chatbot." I'm a machine learning model, which means I've been trained on a large dataset of text and can generate responses based on patterns and context.
|
||||||
|
|
||||||
|
I can help with a wide range of topics, from general knowledge and trivia to more specialized subjects like science, history, and technology. I can also assist with tasks like language translation, text summarization, and even generating creative content like stories or poetry.
|
||||||
|
|
||||||
|
So, what can I help you with today?<|eot_id|>
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
|
##### Notes
|
||||||
|
This format is unchanged from Llama3.1
|
||||||
|
|
||||||
|
## Zero shot function calling
|
||||||
|
|
||||||
|
|
||||||
|
For Llama3.2 1B and 3B instruct models, we are introducing a new format for zero shot function calling.
|
||||||
|
This new format is designed to be more flexible and powerful than the previous format.
|
||||||
|
All available functions can be provided in the system message. A key difference is in the format of how the assistant responds with function calls.
|
||||||
|
It is pythonic in the form of `[func1(params_name=params_value, params_name2=params_value2...), func2(params)]` instead of the `json` or `<function>` tag that were defined in Llama3.1.
|
||||||
|
Here is an example for the same,
|
||||||
|
|
||||||
|
|
||||||
|
##### Input Prompt Format
|
||||||
|
```
|
||||||
|
<|begin_of_text|><|start_header_id|>system<|end_header_id|>
|
||||||
|
|
||||||
|
You are an expert in composing functions. You are given a question and a set of possible functions.
|
||||||
|
Based on the question, you will need to make one or more function/tool calls to achieve the purpose.
|
||||||
|
If none of the function can be used, point it out. If the given question lacks the parameters required by the function,
|
||||||
|
also point it out. You should only return the function call in tools call sections.
|
||||||
|
|
||||||
|
If you decide to invoke any of the function(s), you MUST put it in the format of [func_name1(params_name1=params_value1, params_name2=params_value2...), func_name2(params)]
|
||||||
|
You SHOULD NOT include any other text in the response.
|
||||||
|
|
||||||
|
Here is a list of functions in JSON format that you can invoke.
|
||||||
|
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"name": "get_weather",
|
||||||
|
"description": "Get weather info for places",
|
||||||
|
"parameters": {
|
||||||
|
"type": "dict",
|
||||||
|
"required": [
|
||||||
|
"city"
|
||||||
|
],
|
||||||
|
"properties": {
|
||||||
|
"city": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "The name of the city to get the weather for"
|
||||||
|
},
|
||||||
|
"metric": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "The metric for weather. Options are: celsius, fahrenheit",
|
||||||
|
"default": "celsius"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
]<|eot_id|><|start_header_id|>user<|end_header_id|>
|
||||||
|
|
||||||
|
What is the weather in SF and Seattle?<|eot_id|><|start_header_id|>assistant<|end_header_id|>
|
||||||
|
|
||||||
|
|
||||||
|
```
|
||||||
|
|
||||||
|
##### Model Response Format
|
||||||
|
```
|
||||||
|
[get_weather(city='San Francisco', metric='celsius'), get_weather(city='Seattle', metric='celsius')]<|eot_id|>
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
|
##### Notes
|
||||||
|
|
||||||
|
- The output supports multiple tool calls natively
|
||||||
|
- JSON format for defining the functions in the system prompt is similar to Llama3.1
|
||||||
|
|
||||||
|
|
||||||
|
## Zero shot function calling with user message
|
||||||
|
|
||||||
|
|
||||||
|
While the default is to provide all function calls in a system message, in Llama3.2 text models you can also provide information for all the available tools in a user message.
|
||||||
|
|
||||||
|
|
||||||
|
##### Input Prompt Format
|
||||||
|
```
|
||||||
|
<|begin_of_text|><|start_header_id|>user<|end_header_id|>
|
||||||
|
|
||||||
|
Questions: Can you retrieve the details for the user with the ID 7890, who has black as their special request?
|
||||||
|
Here is a list of functions in JSON format that you can invoke:
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"name": "get_user_info",
|
||||||
|
"description": "Retrieve details for a specific user by their unique identifier. Note that the provided function is in Python 3 syntax.",
|
||||||
|
"parameters": {
|
||||||
|
"type": "dict",
|
||||||
|
"required": [
|
||||||
|
"user_id"
|
||||||
|
],
|
||||||
|
"properties": {
|
||||||
|
"user_id": {
|
||||||
|
"type": "integer",
|
||||||
|
"description": "The unique identifier of the user. It is used to fetch the specific user details from the database."
|
||||||
|
},
|
||||||
|
"special": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "Any special information or parameters that need to be considered while fetching user details.",
|
||||||
|
"default": "none"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
Should you decide to return the function call(s),Put it in the format of [func1(params_name=params_value, params_name2=params_value2...), func2(params)]
|
||||||
|
|
||||||
|
NO other text MUST be included.<|eot_id|><|start_header_id|>assistant<|end_header_id|>
|
||||||
|
|
||||||
|
|
||||||
|
```
|
||||||
|
|
||||||
|
##### Model Response Format
|
||||||
|
```
|
||||||
|
[get_user_info(user_id=7890, special='black')]<|eot_id|>
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
|
##### Notes
|
||||||
|
|
||||||
|
- The tool call format for the model is the same whether your function calls are provided in the system or user message.
|
||||||
|
- While builtin tool calls end with a <|eom_id|>, notice the <|eot_id|> for zero shot tool calls.
|
||||||
|
|
||||||
|
|
||||||
|
## Code Interpreter
|
||||||
|
|
||||||
|
|
||||||
|
Code Interpreter continues to work in 3.2 text models similar to Llama 3.1 model family.
|
||||||
|
Here is an example,
|
||||||
|
|
||||||
|
|
||||||
|
##### Input Prompt Format
|
||||||
|
```
|
||||||
|
<|begin_of_text|><|start_header_id|>system<|end_header_id|>
|
||||||
|
|
||||||
|
Environment: ipython
|
||||||
|
|
||||||
|
Cutting Knowledge Date: December 2023
|
||||||
|
Today Date: 24 September 2024
|
||||||
|
<|eot_id|><|start_header_id|>user<|end_header_id|>
|
||||||
|
|
||||||
|
Write code to check if number is prime. Use it to verify if number 7 is prime<|eot_id|><|start_header_id|>assistant<|end_header_id|>
|
||||||
|
|
||||||
|
|
||||||
|
```
|
||||||
|
|
||||||
|
##### Model Response Format
|
||||||
|
```
|
||||||
|
<|python_tag|>def is_prime(n):
|
||||||
|
if n <= 1:
|
||||||
|
return False
|
||||||
|
if n == 2:
|
||||||
|
return True
|
||||||
|
if n % 2 == 0:
|
||||||
|
return False
|
||||||
|
max_divisor = int(n**0.5) + 1
|
||||||
|
for d in range(3, max_divisor, 2):
|
||||||
|
if n % d == 0:
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
print(is_prime(7)) # Output: True<|eom_id|>
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
|
##### Notes
|
||||||
|
|
||||||
|
- Note `Environment: ipython` in the system prompt.
|
||||||
|
- Note that the response starts with `<|python_tag|>` and ends with `<|eom_id|>`
|
||||||
|
|
||||||
|
|
||||||
|
## Zero shot function calling E2E format
|
||||||
|
|
||||||
|
|
||||||
|
Here is an example of the e2e cycle of tool calls with the model in a muti-step way.
|
||||||
|
|
||||||
|
|
||||||
|
##### Input Prompt Format
|
||||||
|
```
|
||||||
|
<|begin_of_text|><|start_header_id|>system<|end_header_id|>
|
||||||
|
|
||||||
|
You are an expert in composing functions. You are given a question and a set of possible functions.
|
||||||
|
Based on the question, you will need to make one or more function/tool calls to achieve the purpose.
|
||||||
|
If none of the function can be used, point it out. If the given question lacks the parameters required by the function,
|
||||||
|
also point it out. You should only return the function call in tools call sections.
|
||||||
|
|
||||||
|
If you decide to invoke any of the function(s), you MUST put it in the format of [func_name1(params_name1=params_value1, params_name2=params_value2...), func_name2(params)]
|
||||||
|
You SHOULD NOT include any other text in the response.
|
||||||
|
|
||||||
|
Here is a list of functions in JSON format that you can invoke.
|
||||||
|
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"name": "get_weather",
|
||||||
|
"description": "Get weather info for places",
|
||||||
|
"parameters": {
|
||||||
|
"type": "dict",
|
||||||
|
"required": [
|
||||||
|
"city"
|
||||||
|
],
|
||||||
|
"properties": {
|
||||||
|
"city": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "The name of the city to get the weather for"
|
||||||
|
},
|
||||||
|
"metric": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "The metric for weather. Options are: celsius, fahrenheit",
|
||||||
|
"default": "celsius"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
]<|eot_id|><|start_header_id|>user<|end_header_id|>
|
||||||
|
|
||||||
|
What is the weather in SF?<|eot_id|><|start_header_id|>assistant<|end_header_id|>
|
||||||
|
|
||||||
|
<|python_tag|>[get_weather(city="San Francisco", metric="celsius")]<|eot_id|><|start_header_id|>ipython<|end_header_id|>
|
||||||
|
|
||||||
|
"25 C"<|eot_id|><|start_header_id|>assistant<|end_header_id|>
|
||||||
|
|
||||||
|
|
||||||
|
```
|
||||||
|
|
||||||
|
##### Model Response Format
|
||||||
|
```
|
||||||
|
The weather in San Francisco is 25 C.<|eot_id|>
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
|
##### Notes
|
||||||
|
|
||||||
|
- The output of the function call is provided back to the model as a tool response ( in json format ).
|
||||||
|
- Notice `<|start_header_id|>ipython<|end_header_id|>` as the header message preceding the tool response.
|
||||||
|
- The model finally summarizes the information from the tool response and returns the result to the user.
|
||||||
|
|
||||||
|
|
||||||
|
## Prompt format for base models
|
||||||
|
|
||||||
|
|
||||||
|
For base models (Llama3.2-1B and Llama3.2-3B), the prompt format for a simple completion is as follows
|
||||||
|
|
||||||
|
|
||||||
|
##### Input Prompt Format
|
||||||
|
```
|
||||||
|
<|begin_of_text|>The color of the sky is blue but sometimes it can also be
|
||||||
|
```
|
||||||
|
|
||||||
|
##### Model Response Format
|
||||||
|
```
|
||||||
|
gray or even purple. The color of the sky can change depending on the time of day, the weather, and the amount of pollution in the air. The color of the sky can also be affected by the presence of dust, smoke, and other particles in the air.
|
||||||
|
|
||||||
|
## Step 1: Identify the factors that
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
|
##### Notes
|
||||||
|
Same as Llama3.1
|
||||||
|
|
||||||
|
Thank You!
|
141
llama_stack/models/llama/llama3_2/vision_prompt_format.md
Normal file
141
llama_stack/models/llama/llama3_2/vision_prompt_format.md
Normal file
|
@ -0,0 +1,141 @@
|
||||||
|
## User and assistant conversation
|
||||||
|
|
||||||
|
Here is a regular multi-turn user assistant conversation and how its formatted.
|
||||||
|
|
||||||
|
##### Input Prompt Format
|
||||||
|
```
|
||||||
|
<|begin_of_text|><|start_header_id|>system<|end_header_id|>
|
||||||
|
|
||||||
|
You are a helpful assistant<|eot_id|><|start_header_id|>user<|end_header_id|>
|
||||||
|
|
||||||
|
Who are you?<|eot_id|><|start_header_id|>assistant<|end_header_id|>
|
||||||
|
|
||||||
|
|
||||||
|
```
|
||||||
|
|
||||||
|
##### Model Response Format
|
||||||
|
```
|
||||||
|
I'm a helpful assistant, here to provide information, answer questions, and assist with tasks to the best of my abilities. I'm a large language model, which means I can understand and respond to natural language inputs, and I'm constantly learning and improving to provide more accurate and helpful responses.
|
||||||
|
|
||||||
|
I can help with a wide range of topics, from general knowledge and trivia to more specific areas like science, history, technology, and more. I can also assist with tasks like language translation, text summarization, and even generating creative content like stories or dialogues.
|
||||||
|
|
||||||
|
What can I help you with today?<|eot_id|>
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
|
##### Notes
|
||||||
|
This format is unchanged from Llama3.1
|
||||||
|
|
||||||
|
## User and assistant conversation with Images
|
||||||
|
|
||||||
|
This example shows how to pass and image to the model as part of the messages.
|
||||||
|
|
||||||
|
##### Input Prompt Format
|
||||||
|
```
|
||||||
|
<|begin_of_text|><|start_header_id|>user<|end_header_id|>
|
||||||
|
|
||||||
|
<|image|>Describe this image in two sentences<|eot_id|><|start_header_id|>assistant<|end_header_id|>
|
||||||
|
|
||||||
|
|
||||||
|
```
|
||||||
|
|
||||||
|
##### Model Response Format
|
||||||
|
```
|
||||||
|
The image depicts a small dog standing on a skateboard, with its front paws firmly planted on the board and its back paws slightly raised. The dog's fur is predominantly brown and white, with a distinctive black stripe running down its back, and it is wearing a black collar around its neck.<|eot_id|>
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
|
##### Notes
|
||||||
|
|
||||||
|
- The `<|image|>` tag is used to indicate presence of the image
|
||||||
|
- The model isn't an early fusion model so doesn't actually translate an image into several tokens. Instead the cross-attention layers take input "on the side" from a vision encoder
|
||||||
|

|
||||||
|
- Its important to postion the <|image|> tag appropriately in the prompt. Image will only attend to the subsequent text tokens
|
||||||
|
- The <|image|> tag is part of the user message body, implying that it should only come after the header `<|start_header_id|>{role}<|end_header_id|>` in the message body
|
||||||
|
- We recommend using a single image in one prompt
|
||||||
|
|
||||||
|
|
||||||
|
## Builtin and Zero Shot Tool Calling
|
||||||
|
|
||||||
|
|
||||||
|
Llama3.2 vision models follow the same tool calling format as Llama3.1 models when inputs are text only.
|
||||||
|
Use `Environment: ipython` to enable tools.
|
||||||
|
Add `Tools: {{tool_name1}},{{tool_name2}}` for each of the builtin tools.
|
||||||
|
The same builtin tools as Llama3.1 are available,
|
||||||
|
- code_interpreter (for executing python code)
|
||||||
|
- brave_search (to search the web)
|
||||||
|
- wolfram_alpha (for querying wolfram alpha for mathematical questions)
|
||||||
|
|
||||||
|
|
||||||
|
##### Input Prompt Format
|
||||||
|
```
|
||||||
|
<|begin_of_text|><|start_header_id|>system<|end_header_id|>
|
||||||
|
|
||||||
|
Environment: ipython
|
||||||
|
Tools: brave_search, wolfram_alpha
|
||||||
|
Cutting Knowledge Date: December 2023
|
||||||
|
Today Date: 23 September 2024
|
||||||
|
|
||||||
|
You are a helpful assistant.
|
||||||
|
<|eot_id|><|start_header_id|>user<|end_header_id|>
|
||||||
|
|
||||||
|
Search the web for the latest price of 1oz gold?<|eot_id|><|start_header_id|>assistant<|end_header_id|>
|
||||||
|
|
||||||
|
|
||||||
|
```
|
||||||
|
|
||||||
|
##### Model Response Format
|
||||||
|
```
|
||||||
|
<|python_tag|>brave_search.call(query="latest price of 1oz gold")<|eom_id|>
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
|
##### Notes
|
||||||
|
|
||||||
|
- Note the `<|python_tag|>` before `brave_search` function call.
|
||||||
|
- The `<|eom_id|>` tag is used to indicate the end of the message.
|
||||||
|
- Similar to Llama3.1, code_interpreter is not explicitly mentioned but is enabled via `Environment: ipython`.
|
||||||
|
- Tool Calling does NOT work with images in the prompt as of now.
|
||||||
|
|
||||||
|
|
||||||
|
## Prompt format for base models
|
||||||
|
|
||||||
|
|
||||||
|
For base models (Llama3.2-11B-Vision and Llama3.2-90B-Vision), the prompt format for a simple completion is as follows
|
||||||
|
|
||||||
|
|
||||||
|
##### Input Prompt Format
|
||||||
|
```
|
||||||
|
<|begin_of_text|>The color of the sky is blue but sometimes it can also be
|
||||||
|
```
|
||||||
|
|
||||||
|
##### Model Response Format
|
||||||
|
```
|
||||||
|
red, orange, pink, purple, and even black. The color of the sky is determined by the amount of sunlight that is scattered by the atmosphere and the amount of dust and water vapor present in the atmosphere. During sunrise and sunset, the sky can take on a range of colors due to the scattering of light by
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
|
##### Notes
|
||||||
|
- Same as Llama3.1
|
||||||
|
|
||||||
|
## Prompt format for base models with Image
|
||||||
|
|
||||||
|
|
||||||
|
For base models (Llama3.2-11B-Vision and Llama3.2-90B-Vision), here is an example of how the text completion format looks with an image,
|
||||||
|
|
||||||
|
|
||||||
|
##### Input Prompt Format
|
||||||
|
```
|
||||||
|
<|begin_of_text|><|image|>If I had to write a haiku for this one
|
||||||
|
```
|
||||||
|
|
||||||
|
##### Model Response Format
|
||||||
|
```
|
||||||
|
, it would be: A skateboarder's delight, a puppy on a board, a furry little thrill-seeker. This puppy is a true skateboarding enthusiast, always eager to hit the streets and show off his skills. He's a master of the board, gliding effortlessly across the pavement with grace and style.
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
|
##### Notes
|
||||||
|
- Note the placement of the special tags <|begin_of_text|> and <|image|>
|
||||||
|
|
||||||
|
Thank You!
|
|
@ -14,7 +14,7 @@
|
||||||
import textwrap
|
import textwrap
|
||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
from llama_models.datatypes import (
|
from llama_stack.models.llama.datatypes import (
|
||||||
BuiltinTool,
|
BuiltinTool,
|
||||||
RawMessage,
|
RawMessage,
|
||||||
StopReason,
|
StopReason,
|
||||||
|
|
|
@ -16,7 +16,9 @@ import textwrap
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
from llama_models.datatypes import (
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
from llama_stack.models.llama.datatypes import (
|
||||||
RawContent,
|
RawContent,
|
||||||
RawMediaItem,
|
RawMediaItem,
|
||||||
RawMessage,
|
RawMessage,
|
||||||
|
@ -25,7 +27,6 @@ from llama_models.datatypes import (
|
||||||
ToolCall,
|
ToolCall,
|
||||||
ToolPromptFormat,
|
ToolPromptFormat,
|
||||||
)
|
)
|
||||||
from pydantic import BaseModel, Field
|
|
||||||
|
|
||||||
from .llama3.interface import LLama31Interface
|
from .llama3.interface import LLama31Interface
|
||||||
from .llama3.template_data import (
|
from .llama3.template_data import (
|
||||||
|
|
|
@ -61,7 +61,12 @@ from llama_stack.apis.inference import (
|
||||||
UserMessage,
|
UserMessage,
|
||||||
)
|
)
|
||||||
from llama_stack.apis.safety import Safety
|
from llama_stack.apis.safety import Safety
|
||||||
from llama_stack.apis.tools import RAGDocument, ToolGroups, ToolInvocationResult, ToolRuntime
|
from llama_stack.apis.tools import (
|
||||||
|
RAGDocument,
|
||||||
|
ToolGroups,
|
||||||
|
ToolInvocationResult,
|
||||||
|
ToolRuntime,
|
||||||
|
)
|
||||||
from llama_stack.apis.vector_io import VectorIO
|
from llama_stack.apis.vector_io import VectorIO
|
||||||
from llama_stack.models.llama.datatypes import (
|
from llama_stack.models.llama.datatypes import (
|
||||||
BuiltinTool,
|
BuiltinTool,
|
||||||
|
@ -120,13 +125,25 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
def turn_to_messages(self, turn: Turn) -> List[Message]:
|
def turn_to_messages(self, turn: Turn) -> List[Message]:
|
||||||
messages = []
|
messages = []
|
||||||
|
|
||||||
# We do not want to keep adding RAG context to the input messages
|
# NOTE: if a toolcall response is in a step, we do not add it when processing the input messages
|
||||||
# May be this should be a parameter of the agentic instance
|
tool_call_ids = set()
|
||||||
# that can define its behavior in a custom way
|
for step in turn.steps:
|
||||||
|
if step.step_type == StepType.tool_execution.value:
|
||||||
|
for response in step.tool_responses:
|
||||||
|
tool_call_ids.add(response.call_id)
|
||||||
|
|
||||||
for m in turn.input_messages:
|
for m in turn.input_messages:
|
||||||
msg = m.model_copy()
|
msg = m.model_copy()
|
||||||
|
# We do not want to keep adding RAG context to the input messages
|
||||||
|
# May be this should be a parameter of the agentic instance
|
||||||
|
# that can define its behavior in a custom way
|
||||||
if isinstance(msg, UserMessage):
|
if isinstance(msg, UserMessage):
|
||||||
msg.context = None
|
msg.context = None
|
||||||
|
if isinstance(msg, ToolResponseMessage):
|
||||||
|
if msg.call_id in tool_call_ids:
|
||||||
|
# NOTE: do not add ToolResponseMessage here, we'll add them in tool_execution steps
|
||||||
|
continue
|
||||||
|
|
||||||
messages.append(msg)
|
messages.append(msg)
|
||||||
|
|
||||||
for step in turn.steps:
|
for step in turn.steps:
|
||||||
|
@ -260,17 +277,24 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
raise ValueError(f"Session {request.session_id} not found")
|
raise ValueError(f"Session {request.session_id} not found")
|
||||||
|
|
||||||
turns = await self.storage.get_session_turns(request.session_id)
|
turns = await self.storage.get_session_turns(request.session_id)
|
||||||
|
if len(turns) == 0:
|
||||||
|
raise ValueError("No turns found for session")
|
||||||
|
|
||||||
messages = await self.get_messages_from_turns(turns)
|
messages = await self.get_messages_from_turns(turns)
|
||||||
messages.extend(request.tool_responses)
|
messages.extend(request.tool_responses)
|
||||||
|
|
||||||
|
last_turn = turns[-1]
|
||||||
|
last_turn_messages = self.turn_to_messages(last_turn)
|
||||||
last_turn_messages = [
|
last_turn_messages = [
|
||||||
x for x in messages if isinstance(x, UserMessage) or isinstance(x, ToolResponseMessage)
|
x for x in last_turn_messages if isinstance(x, UserMessage) or isinstance(x, ToolResponseMessage)
|
||||||
]
|
]
|
||||||
|
|
||||||
|
# TODO: figure out whether we should add the tool responses to the last turn messages
|
||||||
|
last_turn_messages.extend(request.tool_responses)
|
||||||
|
|
||||||
# get the steps from the turn id
|
# get the steps from the turn id
|
||||||
steps = []
|
steps = []
|
||||||
if len(turns) > 0:
|
steps = turns[-1].steps
|
||||||
steps = turns[-1].steps
|
|
||||||
|
|
||||||
# mark tool execution step as complete
|
# mark tool execution step as complete
|
||||||
# if there's no tool execution in progress step (due to storage, or tool call parsing on client),
|
# if there's no tool execution in progress step (due to storage, or tool call parsing on client),
|
||||||
|
@ -509,9 +533,15 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
if documents:
|
if documents:
|
||||||
await self.handle_documents(session_id, documents, input_messages, tool_defs)
|
await self.handle_documents(session_id, documents, input_messages, tool_defs)
|
||||||
|
|
||||||
|
session_info = await self.storage.get_session_info(session_id)
|
||||||
|
# if the session has a memory bank id, let the memory tool use it
|
||||||
|
if session_info and session_info.vector_db_id:
|
||||||
|
toolgroup_args[RAG_TOOL_GROUP]["vector_db_ids"].append(session_info.vector_db_id)
|
||||||
|
|
||||||
output_attachments = []
|
output_attachments = []
|
||||||
|
|
||||||
n_iter = 0
|
n_iter = await self.storage.get_num_infer_iters_in_turn(session_id, turn_id) or 0
|
||||||
|
|
||||||
# Build a map of custom tools to their definitions for faster lookup
|
# Build a map of custom tools to their definitions for faster lookup
|
||||||
client_tools = {}
|
client_tools = {}
|
||||||
for tool in self.agent_config.client_tools:
|
for tool in self.agent_config.client_tools:
|
||||||
|
@ -586,8 +616,20 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
if event.stop_reason is not None:
|
if event.stop_reason is not None:
|
||||||
stop_reason = event.stop_reason
|
stop_reason = event.stop_reason
|
||||||
span.set_attribute("stop_reason", stop_reason)
|
span.set_attribute("stop_reason", stop_reason)
|
||||||
span.set_attribute("input", [m.model_dump_json() for m in input_messages])
|
span.set_attribute(
|
||||||
span.set_attribute("output", f"content: {content} tool_calls: {tool_calls}")
|
"input",
|
||||||
|
json.dumps([json.loads(m.model_dump_json()) for m in input_messages]),
|
||||||
|
)
|
||||||
|
output_attr = json.dumps(
|
||||||
|
{
|
||||||
|
"content": content,
|
||||||
|
"tool_calls": [json.loads(t.model_dump_json()) for t in tool_calls],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
span.set_attribute("output", output_attr)
|
||||||
|
|
||||||
|
n_iter += 1
|
||||||
|
await self.storage.set_num_infer_iters_in_turn(session_id, turn_id, n_iter)
|
||||||
|
|
||||||
stop_reason = stop_reason or StopReason.out_of_tokens
|
stop_reason = stop_reason or StopReason.out_of_tokens
|
||||||
|
|
||||||
|
@ -624,6 +666,9 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
|
|
||||||
if n_iter >= self.agent_config.max_infer_iters:
|
if n_iter >= self.agent_config.max_infer_iters:
|
||||||
log.info("Done with MAX iterations, exiting.")
|
log.info("Done with MAX iterations, exiting.")
|
||||||
|
# NOTE: mark end_of_turn to indicate to client that we are done with the turn
|
||||||
|
# Do not continue the tool call loop after this point
|
||||||
|
message.stop_reason = StopReason.end_of_turn
|
||||||
yield message
|
yield message
|
||||||
break
|
break
|
||||||
|
|
||||||
|
@ -673,6 +718,9 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
|
|
||||||
# If tool is a client tool, yield CompletionMessage and return
|
# If tool is a client tool, yield CompletionMessage and return
|
||||||
if tool_call.tool_name in client_tools:
|
if tool_call.tool_name in client_tools:
|
||||||
|
# NOTE: mark end_of_message to indicate to client that it may
|
||||||
|
# call the tool and continue the conversation with the tool's response.
|
||||||
|
message.stop_reason = StopReason.end_of_message
|
||||||
await self.storage.set_in_progress_tool_call_step(
|
await self.storage.set_in_progress_tool_call_step(
|
||||||
session_id,
|
session_id,
|
||||||
turn_id,
|
turn_id,
|
||||||
|
@ -758,16 +806,14 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
|
|
||||||
input_messages = input_messages + [message, result_message]
|
input_messages = input_messages + [message, result_message]
|
||||||
|
|
||||||
n_iter += 1
|
|
||||||
|
|
||||||
async def _get_tool_defs(
|
async def _get_tool_defs(
|
||||||
self, toolgroups_for_turn: Optional[List[AgentToolGroup]] = None
|
self, toolgroups_for_turn: Optional[List[AgentToolGroup]] = None
|
||||||
) -> Tuple[List[ToolDefinition], Dict[str, str]]:
|
) -> Tuple[List[ToolDefinition], Dict[str, str]]:
|
||||||
# Determine which tools to include
|
# Determine which tools to include
|
||||||
agent_config_toolgroups = set(
|
agent_config_toolgroups = {
|
||||||
(toolgroup.name if isinstance(toolgroup, AgentToolGroupWithArgs) else toolgroup)
|
toolgroup.name if isinstance(toolgroup, AgentToolGroupWithArgs) else toolgroup
|
||||||
for toolgroup in self.agent_config.toolgroups
|
for toolgroup in self.agent_config.toolgroups
|
||||||
)
|
}
|
||||||
toolgroups_for_turn_set = (
|
toolgroups_for_turn_set = (
|
||||||
agent_config_toolgroups
|
agent_config_toolgroups
|
||||||
if toolgroups_for_turn is None
|
if toolgroups_for_turn is None
|
||||||
|
@ -803,6 +849,11 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
|
|
||||||
toolgroup_name, tool_name = self._parse_toolgroup_name(toolgroup_name_with_maybe_tool_name)
|
toolgroup_name, tool_name = self._parse_toolgroup_name(toolgroup_name_with_maybe_tool_name)
|
||||||
tools = await self.tool_groups_api.list_tools(toolgroup_id=toolgroup_name)
|
tools = await self.tool_groups_api.list_tools(toolgroup_id=toolgroup_name)
|
||||||
|
if not tools.data:
|
||||||
|
available_tool_groups = ", ".join(
|
||||||
|
[t.identifier for t in (await self.tool_groups_api.list_tool_groups()).data]
|
||||||
|
)
|
||||||
|
raise ValueError(f"Toolgroup {toolgroup_name} not found, available toolgroups: {available_tool_groups}")
|
||||||
if tool_name is not None and not any(tool.identifier == tool_name for tool in tools.data):
|
if tool_name is not None and not any(tool.identifier == tool_name for tool in tools.data):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Tool {tool_name} not found in toolgroup {toolgroup_name}. Available tools: {', '.join([tool.identifier for tool in tools.data])}"
|
f"Tool {tool_name} not found in toolgroup {toolgroup_name}. Available tools: {', '.join([tool.identifier for tool in tools.data])}"
|
||||||
|
@ -1025,9 +1076,6 @@ async def execute_tool_call_maybe(
|
||||||
group_name = tool_to_group.get(name, None)
|
group_name = tool_to_group.get(name, None)
|
||||||
if group_name is None:
|
if group_name is None:
|
||||||
raise ValueError(f"Tool {name} not found in any tool group")
|
raise ValueError(f"Tool {name} not found in any tool group")
|
||||||
# get the arguments generated by the model and augment with toolgroup arg overrides for the agent
|
|
||||||
tool_call_args = tool_call.arguments
|
|
||||||
tool_call_args.update(toolgroup_args.get(group_name, {}))
|
|
||||||
if isinstance(name, BuiltinTool):
|
if isinstance(name, BuiltinTool):
|
||||||
if name == BuiltinTool.brave_search:
|
if name == BuiltinTool.brave_search:
|
||||||
name = WEB_SEARCH_TOOL
|
name = WEB_SEARCH_TOOL
|
||||||
|
@ -1036,10 +1084,12 @@ async def execute_tool_call_maybe(
|
||||||
|
|
||||||
result = await tool_runtime_api.invoke_tool(
|
result = await tool_runtime_api.invoke_tool(
|
||||||
tool_name=name,
|
tool_name=name,
|
||||||
kwargs=dict(
|
kwargs={
|
||||||
session_id=session_id,
|
"session_id": session_id,
|
||||||
**tool_call_args,
|
# get the arguments generated by the model and augment with toolgroup arg overrides for the agent
|
||||||
),
|
**tool_call.arguments,
|
||||||
|
**toolgroup_args.get(group_name, {}),
|
||||||
|
},
|
||||||
)
|
)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
|
@ -194,17 +194,13 @@ class MetaReferenceAgentsImpl(Agents):
|
||||||
yield event
|
yield event
|
||||||
|
|
||||||
async def get_agents_turn(self, agent_id: str, session_id: str, turn_id: str) -> Turn:
|
async def get_agents_turn(self, agent_id: str, session_id: str, turn_id: str) -> Turn:
|
||||||
turn = await self.persistence_store.get(f"session:{agent_id}:{session_id}:{turn_id}")
|
agent = await self.get_agent(agent_id)
|
||||||
turn = json.loads(turn)
|
turn = await agent.storage.get_session_turn(session_id, turn_id)
|
||||||
turn = Turn(**turn)
|
|
||||||
return turn
|
return turn
|
||||||
|
|
||||||
async def get_agents_step(self, agent_id: str, session_id: str, turn_id: str, step_id: str) -> AgentStepResponse:
|
async def get_agents_step(self, agent_id: str, session_id: str, turn_id: str, step_id: str) -> AgentStepResponse:
|
||||||
turn = await self.persistence_store.get(f"session:{agent_id}:{session_id}:{turn_id}")
|
turn = await self.get_agents_turn(agent_id, session_id, turn_id)
|
||||||
turn = json.loads(turn)
|
for step in turn.steps:
|
||||||
turn = Turn(**turn)
|
|
||||||
steps = turn.steps
|
|
||||||
for step in steps:
|
|
||||||
if step.step_id == step_id:
|
if step.step_id == step_id:
|
||||||
return AgentStepResponse(step=step)
|
return AgentStepResponse(step=step)
|
||||||
raise ValueError(f"Provided step_id {step_id} could not be found")
|
raise ValueError(f"Provided step_id {step_id} could not be found")
|
||||||
|
@ -215,20 +211,18 @@ class MetaReferenceAgentsImpl(Agents):
|
||||||
session_id: str,
|
session_id: str,
|
||||||
turn_ids: Optional[List[str]] = None,
|
turn_ids: Optional[List[str]] = None,
|
||||||
) -> Session:
|
) -> Session:
|
||||||
session = await self.persistence_store.get(f"session:{agent_id}:{session_id}")
|
agent = await self.get_agent(agent_id)
|
||||||
session = Session(**json.loads(session), turns=[])
|
session_info = await agent.storage.get_session_info(session_id)
|
||||||
turns = []
|
if session_info is None:
|
||||||
|
raise ValueError(f"Session {session_id} not found")
|
||||||
|
turns = await agent.storage.get_session_turns(session_id)
|
||||||
if turn_ids:
|
if turn_ids:
|
||||||
for turn_id in turn_ids:
|
turns = [turn for turn in turns if turn.turn_id in turn_ids]
|
||||||
turn = await self.persistence_store.get(f"session:{agent_id}:{session_id}:{turn_id}")
|
|
||||||
turn = json.loads(turn)
|
|
||||||
turn = Turn(**turn)
|
|
||||||
turns.append(turn)
|
|
||||||
return Session(
|
return Session(
|
||||||
session_name=session.session_name,
|
session_name=session_info.session_name,
|
||||||
session_id=session_id,
|
session_id=session_id,
|
||||||
turns=turns if turns else [],
|
turns=turns,
|
||||||
started_at=session.started_at,
|
started_at=session_info.started_at,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def delete_agents_session(self, agent_id: str, session_id: str) -> None:
|
async def delete_agents_session(self, agent_id: str, session_id: str) -> None:
|
||||||
|
|
|
@ -21,6 +21,7 @@ log = logging.getLogger(__name__)
|
||||||
class AgentSessionInfo(BaseModel):
|
class AgentSessionInfo(BaseModel):
|
||||||
session_id: str
|
session_id: str
|
||||||
session_name: str
|
session_name: str
|
||||||
|
# TODO: is this used anywhere?
|
||||||
vector_db_id: Optional[str] = None
|
vector_db_id: Optional[str] = None
|
||||||
started_at: datetime
|
started_at: datetime
|
||||||
|
|
||||||
|
@ -85,6 +86,14 @@ class AgentPersistence:
|
||||||
turns.sort(key=lambda x: (x.completed_at or datetime.min))
|
turns.sort(key=lambda x: (x.completed_at or datetime.min))
|
||||||
return turns
|
return turns
|
||||||
|
|
||||||
|
async def get_session_turn(self, session_id: str, turn_id: str) -> Optional[Turn]:
|
||||||
|
value = await self.kvstore.get(
|
||||||
|
key=f"session:{self.agent_id}:{session_id}:{turn_id}",
|
||||||
|
)
|
||||||
|
if not value:
|
||||||
|
return None
|
||||||
|
return Turn(**json.loads(value))
|
||||||
|
|
||||||
async def set_in_progress_tool_call_step(self, session_id: str, turn_id: str, step: ToolExecutionStep):
|
async def set_in_progress_tool_call_step(self, session_id: str, turn_id: str, step: ToolExecutionStep):
|
||||||
await self.kvstore.set(
|
await self.kvstore.set(
|
||||||
key=f"in_progress_tool_call_step:{self.agent_id}:{session_id}:{turn_id}",
|
key=f"in_progress_tool_call_step:{self.agent_id}:{session_id}:{turn_id}",
|
||||||
|
@ -96,3 +105,15 @@ class AgentPersistence:
|
||||||
key=f"in_progress_tool_call_step:{self.agent_id}:{session_id}:{turn_id}",
|
key=f"in_progress_tool_call_step:{self.agent_id}:{session_id}:{turn_id}",
|
||||||
)
|
)
|
||||||
return ToolExecutionStep(**json.loads(value)) if value else None
|
return ToolExecutionStep(**json.loads(value)) if value else None
|
||||||
|
|
||||||
|
async def set_num_infer_iters_in_turn(self, session_id: str, turn_id: str, num_infer_iters: int):
|
||||||
|
await self.kvstore.set(
|
||||||
|
key=f"num_infer_iters_in_turn:{self.agent_id}:{session_id}:{turn_id}",
|
||||||
|
value=str(num_infer_iters),
|
||||||
|
)
|
||||||
|
|
||||||
|
async def get_num_infer_iters_in_turn(self, session_id: str, turn_id: str) -> Optional[int]:
|
||||||
|
value = await self.kvstore.get(
|
||||||
|
key=f"num_infer_iters_in_turn:{self.agent_id}:{session_id}:{turn_id}",
|
||||||
|
)
|
||||||
|
return int(value) if value else None
|
||||||
|
|
|
@ -3,6 +3,7 @@
|
||||||
#
|
#
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
import json
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
@ -86,7 +87,6 @@ class MetaReferenceEvalImpl(
|
||||||
) -> Job:
|
) -> Job:
|
||||||
task_def = self.benchmarks[benchmark_id]
|
task_def = self.benchmarks[benchmark_id]
|
||||||
dataset_id = task_def.dataset_id
|
dataset_id = task_def.dataset_id
|
||||||
candidate = task_config.eval_candidate
|
|
||||||
scoring_functions = task_def.scoring_functions
|
scoring_functions = task_def.scoring_functions
|
||||||
dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
|
dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
|
||||||
validate_dataset_schema(dataset_def.dataset_schema, get_valid_schemas(Api.eval.value))
|
validate_dataset_schema(dataset_def.dataset_schema, get_valid_schemas(Api.eval.value))
|
||||||
|
@ -117,7 +117,7 @@ class MetaReferenceEvalImpl(
|
||||||
generations = []
|
generations = []
|
||||||
for i, x in tqdm(enumerate(input_rows)):
|
for i, x in tqdm(enumerate(input_rows)):
|
||||||
assert ColumnName.chat_completion_input.value in x, "Invalid input row"
|
assert ColumnName.chat_completion_input.value in x, "Invalid input row"
|
||||||
input_messages = eval(str(x[ColumnName.chat_completion_input.value]))
|
input_messages = json.loads(x[ColumnName.chat_completion_input.value])
|
||||||
input_messages = [UserMessage(**x) for x in input_messages]
|
input_messages = [UserMessage(**x) for x in input_messages]
|
||||||
|
|
||||||
# NOTE: only single-turn agent generation is supported. Create a new session for each input row
|
# NOTE: only single-turn agent generation is supported. Create a new session for each input row
|
||||||
|
@ -159,7 +159,7 @@ class MetaReferenceEvalImpl(
|
||||||
generations = []
|
generations = []
|
||||||
for x in tqdm(input_rows):
|
for x in tqdm(input_rows):
|
||||||
if ColumnName.completion_input.value in x:
|
if ColumnName.completion_input.value in x:
|
||||||
input_content = eval(str(x[ColumnName.completion_input.value]))
|
input_content = json.loads(x[ColumnName.completion_input.value])
|
||||||
response = await self.inference_api.completion(
|
response = await self.inference_api.completion(
|
||||||
model=candidate.model,
|
model=candidate.model,
|
||||||
content=input_content,
|
content=input_content,
|
||||||
|
@ -167,9 +167,8 @@ class MetaReferenceEvalImpl(
|
||||||
)
|
)
|
||||||
generations.append({ColumnName.generated_answer.value: response.completion_message.content})
|
generations.append({ColumnName.generated_answer.value: response.completion_message.content})
|
||||||
elif ColumnName.chat_completion_input.value in x:
|
elif ColumnName.chat_completion_input.value in x:
|
||||||
chat_completion_input_str = str(x[ColumnName.chat_completion_input.value])
|
chat_completion_input_json = json.loads(x[ColumnName.chat_completion_input.value])
|
||||||
input_messages = eval(chat_completion_input_str)
|
input_messages = [UserMessage(**x) for x in chat_completion_input_json]
|
||||||
input_messages = [UserMessage(**x) for x in input_messages]
|
|
||||||
messages = []
|
messages = []
|
||||||
if candidate.system_message:
|
if candidate.system_message:
|
||||||
messages.append(candidate.system_message)
|
messages.append(candidate.system_message)
|
||||||
|
|
|
@ -23,13 +23,6 @@ from fairscale.nn.model_parallel.initialize import (
|
||||||
initialize_model_parallel,
|
initialize_model_parallel,
|
||||||
model_parallel_is_initialized,
|
model_parallel_is_initialized,
|
||||||
)
|
)
|
||||||
from llama_models.llama3.api.args import ModelArgs
|
|
||||||
from llama_models.llama3.api.chat_format import ChatFormat, LLMInput
|
|
||||||
from llama_models.llama3.api.tokenizer import Tokenizer
|
|
||||||
from llama_models.llama3.reference_impl.model import Transformer
|
|
||||||
from llama_models.llama3.reference_impl.multimodal.model import (
|
|
||||||
CrossAttentionTransformer,
|
|
||||||
)
|
|
||||||
from lmformatenforcer import JsonSchemaParser, TokenEnforcer, TokenEnforcerTokenizerData
|
from lmformatenforcer import JsonSchemaParser, TokenEnforcer, TokenEnforcerTokenizerData
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
@ -46,6 +39,13 @@ from llama_stack.models.llama.datatypes import (
|
||||||
SamplingParams,
|
SamplingParams,
|
||||||
TopPSamplingStrategy,
|
TopPSamplingStrategy,
|
||||||
)
|
)
|
||||||
|
from llama_stack.models.llama.llama3.args import ModelArgs
|
||||||
|
from llama_stack.models.llama.llama3.chat_format import ChatFormat, LLMInput
|
||||||
|
from llama_stack.models.llama.llama3.model import Transformer
|
||||||
|
from llama_stack.models.llama.llama3.multimodal.model import (
|
||||||
|
CrossAttentionTransformer,
|
||||||
|
)
|
||||||
|
from llama_stack.models.llama.llama3.tokenizer import Tokenizer
|
||||||
from llama_stack.models.llama.sku_list import resolve_model
|
from llama_stack.models.llama.sku_list import resolve_model
|
||||||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||||
ChatCompletionRequestWithRawContent,
|
ChatCompletionRequestWithRawContent,
|
||||||
|
|
|
@ -111,7 +111,7 @@ class MetaReferenceInferenceImpl(
|
||||||
)
|
)
|
||||||
if llama_model is None:
|
if llama_model is None:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Please make sure your llama_model in model metadata or model identifier is in llama-models SKU list"
|
"Please make sure your llama_model in model metadata or model identifier is in Llama SKU list"
|
||||||
)
|
)
|
||||||
|
|
||||||
self.model_registry_helper = ModelRegistryHelper(
|
self.model_registry_helper = ModelRegistryHelper(
|
||||||
|
@ -208,7 +208,6 @@ class MetaReferenceInferenceImpl(
|
||||||
logprobs = []
|
logprobs = []
|
||||||
stop_reason = None
|
stop_reason = None
|
||||||
|
|
||||||
tokenizer = self.generator.formatter.tokenizer
|
|
||||||
for token_result in self.generator.completion(request):
|
for token_result in self.generator.completion(request):
|
||||||
tokens.append(token_result.token)
|
tokens.append(token_result.token)
|
||||||
if token_result.text == "<|eot_id|>":
|
if token_result.text == "<|eot_id|>":
|
||||||
|
|
|
@ -9,10 +9,9 @@ from copy import deepcopy
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from typing import Any, Generator
|
from typing import Any, Generator
|
||||||
|
|
||||||
from llama_models.llama3.api.chat_format import ChatFormat
|
|
||||||
from llama_models.llama3.api.tokenizer import Tokenizer
|
|
||||||
|
|
||||||
from llama_stack.models.llama.datatypes import Model
|
from llama_stack.models.llama.datatypes import Model
|
||||||
|
from llama_stack.models.llama.llama3.chat_format import ChatFormat
|
||||||
|
from llama_stack.models.llama.llama3.tokenizer import Tokenizer
|
||||||
from llama_stack.models.llama.sku_list import resolve_model
|
from llama_stack.models.llama.sku_list import resolve_model
|
||||||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||||
ChatCompletionRequestWithRawContent,
|
ChatCompletionRequestWithRawContent,
|
||||||
|
|
|
@ -207,7 +207,7 @@ def maybe_parse_message(maybe_json: Optional[str]) -> Optional[ProcessingMessage
|
||||||
return parse_message(maybe_json)
|
return parse_message(maybe_json)
|
||||||
except json.JSONDecodeError:
|
except json.JSONDecodeError:
|
||||||
return None
|
return None
|
||||||
except ValueError as e:
|
except ValueError:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
@ -352,7 +352,7 @@ class ModelParallelProcessGroup:
|
||||||
if isinstance(obj, TaskResponse):
|
if isinstance(obj, TaskResponse):
|
||||||
yield obj.result
|
yield obj.result
|
||||||
|
|
||||||
except GeneratorExit as e:
|
except GeneratorExit:
|
||||||
self.request_socket.send(encode_msg(CancelSentinel()))
|
self.request_socket.send(encode_msg(CancelSentinel()))
|
||||||
while True:
|
while True:
|
||||||
obj_json = self.request_socket.send()
|
obj_json = self.request_socket.send()
|
||||||
|
|
|
@ -7,6 +7,9 @@
|
||||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
|
# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
|
||||||
|
|
||||||
|
# The file gets a special treatment for now?
|
||||||
|
# ruff: noqa: N803
|
||||||
|
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
|
@ -15,13 +15,13 @@ import torch
|
||||||
from fairscale.nn.model_parallel.initialize import get_model_parallel_rank
|
from fairscale.nn.model_parallel.initialize import get_model_parallel_rank
|
||||||
from fairscale.nn.model_parallel.layers import ColumnParallelLinear, RowParallelLinear
|
from fairscale.nn.model_parallel.layers import ColumnParallelLinear, RowParallelLinear
|
||||||
from fairscale.nn.model_parallel.mappings import reduce_from_model_parallel_region
|
from fairscale.nn.model_parallel.mappings import reduce_from_model_parallel_region
|
||||||
from llama_models.llama3.api.args import ModelArgs
|
|
||||||
from llama_models.llama3.reference_impl.model import Transformer, TransformerBlock
|
|
||||||
from torch import Tensor, nn
|
from torch import Tensor, nn
|
||||||
from torchao.quantization.GPTQ import Int8DynActInt4WeightLinear
|
from torchao.quantization.GPTQ import Int8DynActInt4WeightLinear
|
||||||
|
|
||||||
from llama_stack.apis.inference import QuantizationType
|
from llama_stack.apis.inference import QuantizationType
|
||||||
from llama_stack.models.llama.datatypes import CheckpointQuantizationFormat
|
from llama_stack.models.llama.datatypes import CheckpointQuantizationFormat
|
||||||
|
from llama_stack.models.llama.llama3.args import ModelArgs
|
||||||
|
from llama_stack.models.llama.llama3.model import Transformer, TransformerBlock
|
||||||
from llama_stack.models.llama.sku_list import resolve_model
|
from llama_stack.models.llama.sku_list import resolve_model
|
||||||
|
|
||||||
from ..config import MetaReferenceQuantizedInferenceConfig
|
from ..config import MetaReferenceQuantizedInferenceConfig
|
||||||
|
|
|
@ -22,11 +22,11 @@ from fairscale.nn.model_parallel.initialize import (
|
||||||
initialize_model_parallel,
|
initialize_model_parallel,
|
||||||
model_parallel_is_initialized,
|
model_parallel_is_initialized,
|
||||||
)
|
)
|
||||||
from llama_models.llama3.api.args import ModelArgs
|
|
||||||
from llama_models.llama3.api.tokenizer import Tokenizer
|
|
||||||
from llama_models.llama3.reference_impl.model import Transformer, TransformerBlock
|
|
||||||
from torch.nn.parameter import Parameter
|
from torch.nn.parameter import Parameter
|
||||||
|
|
||||||
|
from llama_stack.models.llama.llama3.args import ModelArgs
|
||||||
|
from llama_stack.models.llama.llama3.model import Transformer, TransformerBlock
|
||||||
|
from llama_stack.models.llama.llama3.tokenizer import Tokenizer
|
||||||
from llama_stack.providers.inline.inference.meta_reference.quantization.fp8_impls import (
|
from llama_stack.providers.inline.inference.meta_reference.quantization.fp8_impls import (
|
||||||
quantize_fp8,
|
quantize_fp8,
|
||||||
)
|
)
|
||||||
|
|
|
@ -21,7 +21,7 @@ NPROC=$7
|
||||||
|
|
||||||
echo $MASTER_HOST, $RUN_ID, $CKPT_DIR, $QUANT_CKPT_DIR
|
echo $MASTER_HOST, $RUN_ID, $CKPT_DIR, $QUANT_CKPT_DIR
|
||||||
|
|
||||||
NCCL_NET=Socket NCCL_SOCKET_IFNAME=eth TIKTOKEN_CACHE_DIR="" PYTHONPATH="/home/$USER/llama-models:/home/$USER/llama-stack" \
|
NCCL_NET=Socket NCCL_SOCKET_IFNAME=eth TIKTOKEN_CACHE_DIR="" PYTHONPATH="/home/$USER/llama-stack" \
|
||||||
torchrun \
|
torchrun \
|
||||||
--nnodes=$NNODES --nproc_per_node=$NPROC \
|
--nnodes=$NNODES --nproc_per_node=$NPROC \
|
||||||
--rdzv_id=$RUN_ID \
|
--rdzv_id=$RUN_ID \
|
||||||
|
|
|
@ -11,5 +11,5 @@ from pydantic import BaseModel
|
||||||
|
|
||||||
class SentenceTransformersInferenceConfig(BaseModel):
|
class SentenceTransformersInferenceConfig(BaseModel):
|
||||||
@classmethod
|
@classmethod
|
||||||
def sample_run_config(cls) -> Dict[str, Any]:
|
def sample_run_config(cls, **kwargs) -> Dict[str, Any]:
|
||||||
return {}
|
return {}
|
||||||
|
|
|
@ -9,7 +9,6 @@ import os
|
||||||
import uuid
|
import uuid
|
||||||
from typing import AsyncGenerator, List, Optional
|
from typing import AsyncGenerator, List, Optional
|
||||||
|
|
||||||
from llama_models.llama3.api.tokenizer import Tokenizer
|
|
||||||
from vllm.engine.arg_utils import AsyncEngineArgs
|
from vllm.engine.arg_utils import AsyncEngineArgs
|
||||||
from vllm.engine.async_llm_engine import AsyncLLMEngine
|
from vllm.engine.async_llm_engine import AsyncLLMEngine
|
||||||
from vllm.sampling_params import SamplingParams as VLLMSamplingParams
|
from vllm.sampling_params import SamplingParams as VLLMSamplingParams
|
||||||
|
@ -36,6 +35,7 @@ from llama_stack.apis.inference import (
|
||||||
ToolPromptFormat,
|
ToolPromptFormat,
|
||||||
)
|
)
|
||||||
from llama_stack.apis.models import Model
|
from llama_stack.apis.models import Model
|
||||||
|
from llama_stack.models.llama.llama3.tokenizer import Tokenizer
|
||||||
from llama_stack.models.llama.sku_list import resolve_model
|
from llama_stack.models.llama.sku_list import resolve_model
|
||||||
from llama_stack.providers.datatypes import ModelsProtocolPrivate
|
from llama_stack.providers.datatypes import ModelsProtocolPrivate
|
||||||
from llama_stack.providers.utils.inference.openai_compat import (
|
from llama_stack.providers.utils.inference.openai_compat import (
|
||||||
|
|
|
@ -10,16 +10,19 @@
|
||||||
# This source code is licensed under the BSD-style license found in the
|
# This source code is licensed under the BSD-style license found in the
|
||||||
# LICENSE file in the root directory of this source tree.
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
|
import json
|
||||||
from typing import Any, Mapping
|
from typing import Any, Mapping
|
||||||
|
|
||||||
from llama_stack.providers.utils.common.data_schema_validator import ColumnName
|
from llama_stack.providers.utils.common.data_schema_validator import ColumnName
|
||||||
|
|
||||||
|
|
||||||
def llama_stack_instruct_to_torchtune_instruct(sample: Mapping[str, Any]) -> Mapping[str, Any]:
|
def llama_stack_instruct_to_torchtune_instruct(
|
||||||
|
sample: Mapping[str, Any],
|
||||||
|
) -> Mapping[str, Any]:
|
||||||
assert ColumnName.chat_completion_input.value in sample and ColumnName.expected_answer.value in sample, (
|
assert ColumnName.chat_completion_input.value in sample and ColumnName.expected_answer.value in sample, (
|
||||||
"Invalid input row"
|
"Invalid input row"
|
||||||
)
|
)
|
||||||
input_messages = eval(str(sample[ColumnName.chat_completion_input.value]))
|
input_messages = json.loads(sample[ColumnName.chat_completion_input.value])
|
||||||
|
|
||||||
assert len(input_messages) == 1, "llama stack intruct dataset format only supports 1 user message"
|
assert len(input_messages) == 1, "llama stack intruct dataset format only supports 1 user message"
|
||||||
input_message = input_messages[0]
|
input_message = input_messages[0]
|
||||||
|
@ -37,7 +40,7 @@ def llama_stack_instruct_to_torchtune_instruct(sample: Mapping[str, Any]) -> Map
|
||||||
def llama_stack_chat_to_torchtune_chat(sample: Mapping[str, Any]) -> Mapping[str, Any]:
|
def llama_stack_chat_to_torchtune_chat(sample: Mapping[str, Any]) -> Mapping[str, Any]:
|
||||||
assert ColumnName.dialog.value in sample, "Invalid input row"
|
assert ColumnName.dialog.value in sample, "Invalid input row"
|
||||||
role_map = {"user": "human", "assistant": "gpt"}
|
role_map = {"user": "human", "assistant": "gpt"}
|
||||||
dialog = eval(str(sample[ColumnName.dialog.value]))
|
dialog = json.loads(sample[ColumnName.dialog.value])
|
||||||
|
|
||||||
assert len(dialog) > 1, "dialog must have at least 2 messagse"
|
assert len(dialog) > 1, "dialog must have at least 2 messagse"
|
||||||
roles = []
|
roles = []
|
||||||
|
|
|
@ -264,7 +264,7 @@ class LoraFinetuningSingleDevice:
|
||||||
)
|
)
|
||||||
|
|
||||||
self.adapter_params = get_adapter_params(model)
|
self.adapter_params = get_adapter_params(model)
|
||||||
self._is_dora = any(["magnitude" in k for k in self.adapter_params.keys()])
|
self._is_dora = any("magnitude" in k for k in self.adapter_params.keys())
|
||||||
|
|
||||||
set_trainable_params(model, self.adapter_params)
|
set_trainable_params(model, self.adapter_params)
|
||||||
|
|
||||||
|
|
Some files were not shown because too many files have changed in this diff Show more
Loading…
Add table
Add a link
Reference in a new issue