mirror of https://github.com/meta-llama/llama-stack.git (synced 2026-01-02 15:04:32 +00:00)

Merge branch 'meta-llama:main' into main

This commit is contained in: commit 80313c8b6a

314 changed files with 170890 additions and 6513 deletions
.gitmodules (vendored, 3 deletions)

@@ -1,3 +0,0 @@
-[submodule "llama_stack/providers/impls/ios/inference/executorch"]
-	path = llama_stack/providers/inline/ios/inference/executorch
-	url = https://github.com/pytorch/executorch
.pre-commit-config.yaml

@@ -8,6 +8,8 @@ repos:
     rev: v5.0.0  # Latest stable version
     hooks:
       - id: check-merge-conflict
+      - id: trailing-whitespace
+        exclude: '\.py$'  # Exclude Python files as Ruff already handles them
       - id: check-added-large-files
         args: ['--maxkb=1000']
       - id: end-of-file-fixer

@@ -42,8 +44,9 @@ repos:
         - black==24.3.0

   - repo: https://github.com/astral-sh/uv-pre-commit
-    rev: 0.5.26
+    rev: 0.6.3
     hooks:
+      - id: uv-lock
       - id: uv-export
         args: [
           "--frozen",

@@ -51,8 +54,6 @@ repos:
           "--no-emit-project",
           "--output-file=requirements.txt"
         ]
-        files: ^pyproject\.toml$
-      - id: uv-sync

   - repo: https://github.com/pre-commit/mirrors-mypy
     rev: v1.15.0

@@ -84,15 +85,12 @@ repos:
       - id: distro-codegen
        name: Distribution Template Codegen
        additional_dependencies:
-         - rich
-         - pydantic
          - uv==0.6.0
-       entry: uv run python -m llama_stack.scripts.distro_codegen
+       entry: uv run --extra codegen python -m llama_stack.scripts.distro_codegen
        language: python
        pass_filenames: false
        require_serial: true
-       files: ^llama_stack/templates/.*$
-       files: ^llama_stack/providers/.*/inference/.*/models\.py$
+       files: ^llama_stack/templates/.*$|^llama_stack/providers/.*/inference/.*/models\.py$

 ci:
     autofix_commit_msg: 🎨 [pre-commit.ci] Auto format from pre-commit.com hooks
.python-version (new file, 1 addition)

3.10
CONTRIBUTING.md

@@ -70,6 +70,19 @@ $ uv pip install -e .
 $ source .venv/bin/activate
 ```

+Note that you can create a dotenv file `.env` that includes necessary environment variables:
+```
+LLAMA_STACK_BASE_URL=http://localhost:8321
+LLAMA_STACK_CLIENT_LOG=debug
+LLAMA_STACK_PORT=8321
+LLAMA_STACK_CONFIG=
+```
+
+And then use this dotenv file when running client SDK tests via the following:
+```bash
+$ uv run --env-file .env -- pytest -v tests/api/inference/test_text_inference.py
+```
+
 ## Pre-commit Hooks

 We use [pre-commit](https://pre-commit.com/) to run linting and formatting checks on your code. You can install the pre-commit hooks by running:
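For orientation, the variables in that `.env` file are consumed by the client SDK tests at runtime. A minimal sketch of how a test might build a client from them, assuming only the documented `LLAMA_STACK_BASE_URL` variable and the public `LlamaStackClient` constructor (the real test fixtures may differ):

```python
import os

from llama_stack_client import LlamaStackClient

# LLAMA_STACK_BASE_URL is injected by `uv run --env-file .env -- pytest ...`
base_url = os.environ.get("LLAMA_STACK_BASE_URL", "http://localhost:8321")
client = LlamaStackClient(base_url=base_url)

# A test can then talk to the running stack, e.g. list the models it serves
for model in client.models.list():
    print(model.identifier)
```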
@@ -110,15 +123,15 @@ Some tips about common tasks you work on while contributing to Llama Stack:

 ### Using `llama stack build`

-Building a stack image (conda / docker) will use the production version of the `llama-stack`, `llama-models` and `llama-stack-client` packages. If you are developing with a llama-stack repository checked out and need your code to be reflected in the stack image, set `LLAMA_STACK_DIR` and `LLAMA_MODELS_DIR` to the appropriate checked out directories when running any of the `llama` CLI commands.
+Building a stack image (conda / docker) will use the production version of the `llama-stack` and `llama-stack-client` packages. If you are developing with a llama-stack repository checked out and need your code to be reflected in the stack image, set `LLAMA_STACK_DIR` and `LLAMA_STACK_CLIENT_DIR` to the appropriate checked out directories when running any of the `llama` CLI commands.

 Example:
 ```bash
 $ cd work/
 $ git clone https://github.com/meta-llama/llama-stack.git
-$ git clone https://github.com/meta-llama/llama-models.git
+$ git clone https://github.com/meta-llama/llama-stack-client-python.git
 $ cd llama-stack
-$ LLAMA_STACK_DIR=$(pwd) LLAMA_MODELS_DIR=../llama-models llama stack build --template <...>
+$ LLAMA_STACK_DIR=$(pwd) LLAMA_STACK_CLIENT_DIR=../llama-stack-client-python llama stack build --template <...>
 ```
MANIFEST.in

@@ -3,4 +3,4 @@ include distributions/dependencies.json
 include llama_stack/distribution/*.sh
 include llama_stack/cli/scripts/*.sh
 include llama_stack/templates/*/*.yaml
-include llama_stack/providers/tests/test_cases/*.json
+include llama_stack/providers/tests/test_cases/inference/*.json
distributions/dependencies.json

@@ -136,6 +136,42 @@
        "sentence-transformers --no-deps",
        "torch torchvision --index-url https://download.pytorch.org/whl/cpu"
    ],
+   "dev": [
+       "aiosqlite",
+       "autoevals",
+       "blobfile",
+       "chardet",
+       "chromadb-client",
+       "datasets",
+       "fastapi",
+       "fire",
+       "fireworks-ai",
+       "httpx",
+       "litellm",
+       "matplotlib",
+       "mcp",
+       "nltk",
+       "numpy",
+       "openai",
+       "opentelemetry-exporter-otlp-proto-http",
+       "opentelemetry-sdk",
+       "pandas",
+       "pillow",
+       "psycopg2-binary",
+       "pymongo",
+       "pypdf",
+       "redis",
+       "requests",
+       "scikit-learn",
+       "scipy",
+       "sentencepiece",
+       "sqlite-vec",
+       "tqdm",
+       "transformers",
+       "uvicorn",
+       "sentence-transformers --no-deps",
+       "torch torchvision --index-url https://download.pytorch.org/whl/cpu"
+   ],
    "fireworks": [
        "aiosqlite",
        "autoevals",

@@ -171,6 +207,37 @@
        "sentence-transformers --no-deps",
        "torch torchvision --index-url https://download.pytorch.org/whl/cpu"
    ],
+   "groq": [
+       "aiosqlite",
+       "autoevals",
+       "blobfile",
+       "chardet",
+       "datasets",
+       "faiss-cpu",
+       "fastapi",
+       "fire",
+       "httpx",
+       "litellm",
+       "matplotlib",
+       "nltk",
+       "numpy",
+       "openai",
+       "opentelemetry-exporter-otlp-proto-http",
+       "opentelemetry-sdk",
+       "pandas",
+       "pillow",
+       "psycopg2-binary",
+       "pymongo",
+       "pypdf",
+       "redis",
+       "requests",
+       "scikit-learn",
+       "scipy",
+       "sentencepiece",
+       "tqdm",
+       "transformers",
+       "uvicorn"
+   ],
    "hf-endpoint": [
        "aiohttp",
        "aiosqlite",

@@ -364,6 +431,7 @@
        "fire",
        "httpx",
        "matplotlib",
+       "mcp",
        "nltk",
        "numpy",
        "ollama",
docs/_static/llama-stack-spec.html (vendored, 1087 changes): file diff suppressed because it is too large.
docs/_static/llama-stack-spec.yaml (vendored, 904 changes): file diff suppressed because it is too large.
File diff suppressed because one or more lines are too long.
docs/notebooks/Alpha_Llama_Stack_Post_Training.ipynb (new file, 6436 additions): file diff suppressed because one or more lines are too long.
@@ -45,65 +45,7 @@
     "id": "O9pGVlPIjpix",
     "outputId": "e1fbe723-ae31-4630-eb80-4c4f6476d56f"
    },
-   "outputs": [
-    ... (dozens of captured "Requirement already satisfied: ..." pip-output lines deleted from this cell's outputs) ...
-   ],
+   "outputs": [],
    "source": [
     "# NBVAL_SKIP\n",
     "!pip install -U llama-stack"

@@ -120,198 +62,10 @@
     "id": "JQpLUSNjlGAM",
     "outputId": "2f7fec97-5511-4cae-d51e-6d262fbca19c"
    },
-   "outputs": [
-    ... (the cell's captured `llama stack build` output, from the pip "Requirement already satisfied" lines through "Build Successful!", deleted) ...
-   ],
+   "outputs": [],
    "source": [
     "# NBVAL_SKIP\n",
-    "!llama stack build --template together --image-type venv"
+    "!UV_SYSTEM_PYTHON=1 llama stack build --template together --image-type venv"
    ]
   },
   {

@@ -1027,7 +781,7 @@
     "    benchmark_id=\"meta-reference::mmmu\",\n",
     "    input_rows=eval_rows,\n",
     "    scoring_functions=[\"basic::regex_parser_multiple_choice_answer\"],\n",
-    "    task_config={\n",
+    "    benchmark_config={\n",
     "        \"type\": \"benchmark\",\n",
     "        \"eval_candidate\": {\n",
     "            \"type\": \"model\",\n",

@@ -1206,7 +960,7 @@
     "    benchmark_id=\"meta-reference::simpleqa\",\n",
     "    input_rows=eval_rows.rows,\n",
     "    scoring_functions=[\"llm-as-judge::405b-simpleqa\"],\n",
-    "    task_config={\n",
+    "    benchmark_config={\n",
     "        \"type\": \"benchmark\",\n",
     "        \"eval_candidate\": {\n",
     "            \"type\": \"model\",\n",

@@ -1355,7 +1109,7 @@
     "    benchmark_id=\"meta-reference::simpleqa\",\n",
     "    input_rows=eval_rows.rows,\n",
     "    scoring_functions=[\"llm-as-judge::405b-simpleqa\"],\n",
-    "    task_config={\n",
+    "    benchmark_config={\n",
     "        \"type\": \"benchmark\",\n",
     "        \"eval_candidate\": {\n",
     "            \"type\": \"agent\",\n",
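The three hunks above rename the eval cells' `task_config` keyword to `benchmark_config`. Pieced together, one such call looks roughly like the sketch below; the enclosing `client.eval.evaluate_rows(...)` name and the candidate's `model`/`sampling_params` fields are assumptions, since the diff only shows fragments of the argument list:

```python
# Illustrative reconstruction of one eval cell after the rename (not copied from the notebook).
# `client` and `eval_rows` are assumed to be defined earlier in the notebook.
response = client.eval.evaluate_rows(  # assumed method name
    benchmark_id="meta-reference::mmmu",
    input_rows=eval_rows,
    scoring_functions=["basic::regex_parser_multiple_choice_answer"],
    benchmark_config={  # previously passed as task_config=
        "type": "benchmark",
        "eval_candidate": {
            "type": "model",
            "model": "meta-llama/Llama-3.2-11B-Vision-Instruct",  # assumed model id
            "sampling_params": {"temperature": 0.0, "max_tokens": 512},  # assumed
        },
    },
)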
@@ -3,7 +3,7 @@ The RFC Specification (OpenAPI format) is generated from the set of API endpoint
 Please install the following packages before running the script:

 ```
-pip install fire PyYAML llama-models
+pip install fire PyYAML
 ```

 Then simply run `sh run_openapi_generator.sh`
@@ -55,6 +55,7 @@ def main(output_dir: str):
             a set of endpoints and their corresponding interfaces that are tailored to
             best leverage Llama Models.""",
         ),
+        include_standard_error_responses=True,
     ),
 )
@@ -10,6 +10,7 @@ import typing
 from dataclasses import make_dataclass
 from typing import Any, Dict, Set, Union

+from llama_stack.apis.datatypes import Error
 from llama_stack.strong_typing.core import JsonType
 from llama_stack.strong_typing.docstring import Docstring, parse_type
 from llama_stack.strong_typing.inspection import (

@@ -435,6 +436,75 @@ class Generator:
         self.schema_builder = SchemaBuilder(schema_generator)
         self.responses = {}

+        # Create standard error responses
+        self._create_standard_error_responses()
+
+    def _create_standard_error_responses(self) -> None:
+        """
+        Creates standard error responses that can be reused across operations.
+        These will be added to the components.responses section of the OpenAPI document.
+        """
+        # Get the Error schema
+        error_schema = self.schema_builder.classdef_to_ref(Error)
+
+        # Create standard error responses
+        self.responses["BadRequest400"] = Response(
+            description="The request was invalid or malformed",
+            content={
+                "application/json": MediaType(
+                    schema=error_schema,
+                    example={
+                        "status": 400,
+                        "title": "Bad Request",
+                        "detail": "The request was invalid or malformed",
+                    },
+                )
+            },
+        )
+
+        self.responses["TooManyRequests429"] = Response(
+            description="The client has sent too many requests in a given amount of time",
+            content={
+                "application/json": MediaType(
+                    schema=error_schema,
+                    example={
+                        "status": 429,
+                        "title": "Too Many Requests",
+                        "detail": "You have exceeded the rate limit. Please try again later.",
+                    },
+                )
+            },
+        )
+
+        self.responses["InternalServerError500"] = Response(
+            description="The server encountered an unexpected error",
+            content={
+                "application/json": MediaType(
+                    schema=error_schema,
+                    example={
+                        "status": 500,
+                        "title": "Internal Server Error",
+                        "detail": "An unexpected error occurred. Our team has been notified.",
+                    },
+                )
+            },
+        )
+
+        # Add a default error response for any unhandled error cases
+        self.responses["DefaultError"] = Response(
+            description="An unexpected error occurred",
+            content={
+                "application/json": MediaType(
+                    schema=error_schema,
+                    example={
+                        "status": 0,
+                        "title": "Error",
+                        "detail": "An unexpected error occurred",
+                    },
+                )
+            },
+        )
+
     def _build_type_tag(self, ref: str, schema: Schema) -> Tag:
         # Don't include schema definition in the tag description because for one,
         # it is not very valuable and for another, it causes string formatting

@@ -649,6 +719,18 @@ class Generator:
         responses.update(response_builder.build_response(response_options))

         assert len(responses.keys()) > 0, f"No responses found for {op.name}"
+
+        # Add standard error response references
+        if self.options.include_standard_error_responses:
+            if "400" not in responses:
+                responses["400"] = ResponseRef("BadRequest400")
+            if "429" not in responses:
+                responses["429"] = ResponseRef("TooManyRequests429")
+            if "500" not in responses:
+                responses["500"] = ResponseRef("InternalServerError500")
+            if "default" not in responses:
+                responses["default"] = ResponseRef("DefaultError")
+
         if op.event_type is not None:
             builder = ContentBuilder(self.schema_builder)
             callbacks = {
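For orientation, here is a rough sketch of what these additions produce in the emitted OpenAPI document: shared entries under `components.responses`, plus `$ref` pointers added to any operation that does not already declare those status codes. The exact serialized field layout is an assumption based on the `Response`/`MediaType` fields above, shown as Python dicts:

```python
# Approximate shape of the generated spec fragments (assumed serialization).
components_responses = {
    "BadRequest400": {
        "description": "The request was invalid or malformed",
        "content": {
            "application/json": {
                "schema": {"$ref": "#/components/schemas/Error"},
                "example": {
                    "status": 400,
                    "title": "Bad Request",
                    "detail": "The request was invalid or malformed",
                },
            }
        },
    },
    # TooManyRequests429, InternalServerError500 and DefaultError follow the same pattern.
}

# Each operation then points at the shared responses unless it already defines them:
operation_responses = {
    "400": {"$ref": "#/components/responses/BadRequest400"},
    "429": {"$ref": "#/components/responses/TooManyRequests429"},
    "500": {"$ref": "#/components/responses/InternalServerError500"},
    "default": {"$ref": "#/components/responses/DefaultError"},
}
```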
@@ -35,6 +35,7 @@ class Options:
     :param error_wrapper: True if errors are encapsulated in an error object wrapper.
     :param property_description_fun: Custom transformation function to apply to class property documentation strings.
     :param captions: User-defined captions for sections such as "Operations" or "Types", and (if applicable) groups of extra types.
+    :param include_standard_error_responses: Whether to include standard error responses (400, 429, 500, 503) in all operations.
     """

     server: Server

@@ -52,6 +53,7 @@ class Options:
     error_wrapper: bool = False
     property_description_fun: Optional[Callable[[type, str, str], str]] = None
     captions: Optional[Dict[str, str]] = None
+    include_standard_error_responses: bool = True

     default_captions: ClassVar[Dict[str, str]] = {
         "Operations": "Operations",
run_openapi_generator.sh

@@ -28,6 +28,5 @@ if [ ${#missing_packages[@]} -ne 0 ]; then
 fi

 stack_dir=$(dirname $(dirname $THIS_DIR))
-models_dir=$(dirname $stack_dir)/llama-models
-PYTHONPATH=$PYTHONPATH:$stack_dir:$models_dir \
+PYTHONPATH=$PYTHONPATH:$stack_dir \
 python -m docs.openapi_generator.generate $(dirname $THIS_DIR)/_static
@@ -11,3 +11,4 @@ sphinxcontrib-openapi
 sphinxcontrib-redoc
 sphinxcontrib-mermaid
 sphinxcontrib-video
+tomli
docs/source/building_applications/agent.md (new file, 91 additions)

# Llama Stack Agent Framework

The Llama Stack agent framework is built on a modular architecture that allows for flexible and powerful AI applications. This document explains the key components and how they work together.

## Core Concepts

### 1. Agent Configuration

Agents are configured using the `AgentConfig` class, which includes:

- **Model**: The underlying LLM to power the agent
- **Instructions**: System prompt that defines the agent's behavior
- **Tools**: Capabilities the agent can use to interact with external systems
- **Safety Shields**: Guardrails to ensure responsible AI behavior

```python
from llama_stack_client.types.agent_create_params import AgentConfig
from llama_stack_client.lib.agents.agent import Agent

# Configure an agent
agent_config = AgentConfig(
    model="meta-llama/Llama-3-70b-chat",
    instructions="You are a helpful assistant that can use tools to answer questions.",
    toolgroups=["builtin::code_interpreter", "builtin::rag/knowledge_search"],
)

# Create the agent
agent = Agent(llama_stack_client, agent_config)
```

### 2. Sessions

Agents maintain state through sessions, which represent a conversation thread:

```python
# Create a session
session_id = agent.create_session(session_name="My conversation")
```

### 3. Turns

Each interaction with an agent is called a "turn" and consists of:

- **Input Messages**: What the user sends to the agent
- **Steps**: The agent's internal processing (inference, tool execution, etc.)
- **Output Message**: The agent's response

```python
from llama_stack_client.lib.agents.event_logger import EventLogger

# Create a turn with streaming response
turn_response = agent.create_turn(
    session_id=session_id,
    messages=[{"role": "user", "content": "Tell me about Llama models"}],
)
for log in EventLogger().log(turn_response):
    log.print()
```

### Non-Streaming

```python
from rich.pretty import pprint

# Non-streaming API
response = agent.create_turn(
    session_id=session_id,
    messages=[{"role": "user", "content": "Tell me about Llama models"}],
    stream=False,
)
print("Inputs:")
pprint(response.input_messages)
print("Output:")
pprint(response.output_message.content)
print("Steps:")
pprint(response.steps)
```

### 4. Steps

Each turn consists of multiple steps that represent the agent's thought process:

- **Inference Steps**: The agent generating text responses
- **Tool Execution Steps**: The agent using tools to gather information
- **Shield Call Steps**: Safety checks being performed
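A minimal sketch of inspecting these step types on a completed turn, continuing the non-streaming example above; the `step_type` string values and the `tool_calls` attribute are assumptions about the client's step objects rather than something stated here:

```python
# `response` is the completed turn from the non-streaming example above.
for step in response.steps:
    # Assumed step_type values: "inference", "tool_execution", "shield_call"
    print("step:", step.step_type)
    if step.step_type == "tool_execution":
        # Inspect which tools were invoked in this step (attribute names assumed)
        for tool_call in step.tool_calls:
            print("  called tool:", tool_call.tool_name)
```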

## Agent Execution Loop

Refer to the [Agent Execution Loop](agent_execution_loop) for more details on what happens within an agent turn.
@@ -7,13 +7,13 @@ Each agent turn follows these key steps:

1. **Initial Safety Check**: The user's input is first screened through configured safety shields

2. **Context Retrieval**:
-   - If RAG is enabled, the agent queries relevant documents from memory banks
+   - If RAG is enabled, the agent can choose to query relevant documents from memory banks. You can use the `instructions` field to steer the agent.
-   - For new documents, they are first inserted into the memory bank
+   - For new documents, they are first inserted into the memory bank.
-   - Retrieved context is augmented to the user's prompt
+   - Retrieved context is provided to the LLM as a tool response in the message history.

3. **Inference Loop**: The agent enters its main execution loop:
-   - The LLM receives the augmented prompt (with context and/or previous tool outputs)
+   - The LLM receives a user prompt (with previous tool outputs)
-   - The LLM generates a response, potentially with tool calls
+   - The LLM generates a response, potentially with [tool calls](tools)
    - If tool calls are present:
      - Tool inputs are safety-checked
      - Tools are executed (e.g., web search, code execution)

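The loop repeats until the model stops calling tools or a stop condition is reached. As a hedged sketch of bounding it from the client side, assuming the `max_infer_iters` field is available on `AgentConfig` in your client version (the model and toolgroup names below are only illustrative):

```python
# Sketch: cap how many inference/tool-call iterations a single turn may run.
agent_config = AgentConfig(
    model="Llama3.2-3B-Instruct",
    instructions="You are a helpful assistant",
    toolgroups=["builtin::websearch"],
    max_infer_iters=5,  # assumption: stop the loop after 5 LLM iterations
)
```
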
@@ -40,19 +40,16 @@ sequenceDiagram
    S->>E: Input Safety Check
    deactivate S

-   E->>M: 2.1 Query Context
-   M-->>E: 2.2 Retrieved Documents

    loop Inference Loop
-       E->>L: 3.1 Augment with Context
+       E->>L: 2.1 Augment with Context
-       L-->>E: 3.2 Response (with/without tool calls)
+       L-->>E: 2.2 Response (with/without tool calls)

        alt Has Tool Calls
            E->>S: Check Tool Input
-           S->>T: 4.1 Execute Tool
+           S->>T: 3.1 Execute Tool
-           T-->>E: 4.2 Tool Response
+           T-->>E: 3.2 Tool Response
-           E->>L: 5.1 Tool Response
+           E->>L: 4.1 Tool Response
-           L-->>E: 5.2 Synthesized Response
+           L-->>E: 4.2 Synthesized Response
        end

        opt Stop Conditions

@@ -64,23 +61,34 @@ sequenceDiagram
        end
    end

    E->>S: Output Safety Check
-   S->>U: 6. Final Response
+   S->>U: 5. Final Response
```

Each step in this process can be monitored and controlled through configurations. Here's an example that demonstrates monitoring the agent's execution:

```python
+from llama_stack_client import LlamaStackClient
+from llama_stack_client.lib.agents.agent import Agent
from llama_stack_client.lib.agents.event_logger import EventLogger
+from llama_stack_client.types.agent_create_params import AgentConfig
+from rich.pretty import pprint
+
+# Replace host and port
+client = LlamaStackClient(base_url=f"http://{HOST}:{PORT}")

agent_config = AgentConfig(
+    # Check with `llama-stack-client models list`
    model="Llama3.2-3B-Instruct",
    instructions="You are a helpful assistant",
    # Enable both RAG and tool usage
    toolgroups=[
-        {"name": "builtin::rag", "args": {"vector_db_ids": ["my_docs"]}},
+        {
+            "name": "builtin::rag/knowledge_search",
+            "args": {"vector_db_ids": ["my_docs"]},
+        },
        "builtin::code_interpreter",
    ],
-    # Configure safety
+    # Configure safety (optional)
    input_shields=["llama_guard"],
    output_shields=["llama_guard"],
    # Control the inference loop

@@ -97,7 +105,7 @@ session_id = agent.create_session("monitored_session")
# Stream the agent's execution steps
response = agent.create_turn(
    messages=[{"role": "user", "content": "Analyze this code and run it"}],
-    attachments=[
+    documents=[
        {
            "content": "https://raw.githubusercontent.com/example/code.py",
            "mime_type": "text/plain",

@@ -108,14 +116,21 @@ response = agent.create_turn(
# Monitor each step of execution
for log in EventLogger().log(response):
-    if log.event.step_type == "memory_retrieval":
-        print("Retrieved context:", log.event.retrieved_context)
-    elif log.event.step_type == "inference":
-        print("LLM output:", log.event.model_response)
-    elif log.event.step_type == "tool_execution":
-        print("Tool call:", log.event.tool_call)
-        print("Tool response:", log.event.tool_response)
-    elif log.event.step_type == "shield_call":
-        if log.event.violation:
-            print("Safety violation:", log.event.violation)
+    log.print()
+
+# Using non-streaming API, the response contains input, steps, and output.
+response = agent.create_turn(
+    messages=[{"role": "user", "content": "Analyze this code and run it"}],
+    documents=[
+        {
+            "content": "https://raw.githubusercontent.com/example/code.py",
+            "mime_type": "text/plain",
+        }
+    ],
+    session_id=session_id,
+)
+
+pprint(f"Input: {response.input_messages}")
+pprint(f"Output: {response.output_message.content}")
+pprint(f"Steps: {response.steps}")
```

@@ -51,7 +51,7 @@ response = client.eval.evaluate_rows(
    benchmark_id="meta-reference::mmmu",
    input_rows=eval_rows,
    scoring_functions=["basic::regex_parser_multiple_choice_answer"],
-    task_config={
+    benchmark_config={
        "type": "benchmark",
        "eval_candidate": {
            "type": "model",

@@ -109,7 +109,7 @@ response = client.eval.evaluate_rows(
    benchmark_id="meta-reference::simpleqa",
    input_rows=eval_rows.rows,
    scoring_functions=["llm-as-judge::405b-simpleqa"],
-    task_config={
+    benchmark_config={
        "type": "benchmark",
        "eval_candidate": {
            "type": "model",

@@ -149,7 +149,6 @@ agent_config = {
        }
    ],
    "tool_choice": "auto",
-    "tool_prompt_format": "json",
    "input_shields": [],
    "output_shields": [],
    "enable_session_persistence": False,

@@ -159,7 +158,7 @@ response = client.eval.evaluate_rows(
    benchmark_id="meta-reference::simpleqa",
    input_rows=eval_rows.rows,
    scoring_functions=["llm-as-judge::405b-simpleqa"],
-    task_config={
+    benchmark_config={
        "type": "benchmark",
        "eval_candidate": {
            "type": "agent",

@@ -19,7 +19,7 @@ response = client.benchmarks.register(
# Run evaluation
job = client.eval.run_eval(
    benchmark_id="my_eval",
-    task_config={
+    benchmark_config={
        "type": "app",
        "eval_candidate": {"type": "agent", "config": agent_config},
    },

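Taken together, these hunks rename `task_config` to `benchmark_config` across the eval calls. As a hedged sketch of a complete call after the rename (the `model` and `sampling_params` fields are illustrative assumptions, not taken from the hunks above, and may need adjusting to your client version):

```python
# Sketch: evaluate rows of a registered benchmark against a model candidate.
response = client.eval.evaluate_rows(
    benchmark_id="meta-reference::simpleqa",
    input_rows=eval_rows.rows,
    scoring_functions=["llm-as-judge::405b-simpleqa"],
    benchmark_config={
        "type": "benchmark",
        "eval_candidate": {
            "type": "model",
            "model": "meta-llama/Llama-3.3-70B-Instruct",  # assumed model id
            "sampling_params": {"temperature": 0.0, "max_tokens": 512},  # assumed shape
        },
    },
)
```
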
@@ -8,22 +8,24 @@ The best way to get started is to look at this notebook which walks through the

Here are some key topics that will help you build effective agents:

-- **[Agent Execution Loop](agent_execution_loop)**
-- **[RAG](rag)**
-- **[Safety](safety)**
-- **[Tools](tools)**
-- **[Telemetry](telemetry)**
-- **[Evals](evals)**
+- **[Agent](agent)**: Understand the components and design patterns of the Llama Stack agent framework.
+- **[Agent Execution Loop](agent_execution_loop)**: Understand how agents process information, make decisions, and execute actions in a continuous loop.
+- **[RAG (Retrieval-Augmented Generation)](rag)**: Learn how to enhance your agents with external knowledge through retrieval mechanisms.
+- **[Tools](tools)**: Extend your agents' capabilities by integrating with external tools and APIs.
+- **[Evals](evals)**: Evaluate your agents' effectiveness and identify areas for improvement.
+- **[Telemetry](telemetry)**: Monitor and analyze your agents' performance and behavior.
+- **[Safety](safety)**: Implement guardrails and safety measures to ensure responsible AI behavior.

```{toctree}
:hidden:
:maxdepth: 1

+agent
agent_execution_loop
rag
-safety
tools
telemetry
evals
+advanced_agent_patterns
+safety
```

@@ -1,8 +1,8 @@
-## Using "Memory" or Retrieval Augmented Generation (RAG)
+## Using Retrieval Augmented Generation (RAG)

-Memory enables your applications to reference and recall information from previous interactions or external documents.
+RAG enables your applications to reference and recall information from previous interactions or external documents.

-Llama Stack organizes the memory APIs into three layers:
+Llama Stack organizes the APIs that enable RAG into three layers:
 - the lowermost APIs deal with raw storage and retrieval. These include Vector IO, KeyValue IO (coming soon) and Relational IO (also coming soon.)
 - next is the "Rag Tool", a first-class tool as part of the Tools API that allows you to ingest documents (from URLs, files, etc) with various chunking strategies and query them smartly.
 - finally, it all comes together with the top-level "Agents" API that allows you to create agents that can use the tools to answer questions, perform tasks, and more.

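To make the layering concrete, here is a hedged sketch of the middle layer (the RAG tool) used directly: register a vector store, ingest a document, then query it without an agent. The type and argument names (`Document`, `chunk_size_in_tokens`, `vector_db_ids`, the embedding model) are assumptions that may differ slightly across client versions.

```python
from llama_stack_client.types import Document

# Register a vector database to back the RAG tool (assumed embedding model).
client.vector_dbs.register(
    vector_db_id="my_docs",
    embedding_model="all-MiniLM-L6-v2",
    embedding_dimension=384,
)

# Ingest a document with a chosen chunking strategy.
client.tool_runtime.rag_tool.insert(
    documents=[
        Document(
            document_id="doc-1",
            content="https://raw.githubusercontent.com/pytorch/torchtune/main/README.md",
            mime_type="text/plain",
            metadata={},
        )
    ],
    vector_db_id="my_docs",
    chunk_size_in_tokens=512,
)

# Query it directly, without going through an agent.
result = client.tool_runtime.rag_tool.query(
    content="How do I fine-tune a model with torchtune?",
    vector_db_ids=["my_docs"],
)
```
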
@@ -86,12 +86,12 @@ from llama_stack_client.lib.agents.agent import Agent

# Configure agent with memory
agent_config = AgentConfig(
-    model="meta-llama/Llama-3.2-3B-Instruct",
+    model="meta-llama/Llama-3.3-70B-Instruct",
    instructions="You are a helpful assistant",
    enable_session_persistence=False,
    toolgroups=[
        {
-            "name": "builtin::rag",
+            "name": "builtin::rag/knowledge_search",
            "args": {
                "vector_db_ids": [vector_db_id],
            },

@@ -102,6 +102,19 @@ agent_config = AgentConfig(
agent = Agent(client, agent_config)
session_id = agent.create_session("rag_session")

+
+# Ask questions about documents in the vector db, and the agent will query the db to answer the question.
+response = agent.create_turn(
+    messages=[{"role": "user", "content": "How to optimize memory in PyTorch?"}],
+    session_id=session_id,
+)
+```
+
+> **NOTE:** the `instructions` field in the `AgentConfig` can be used to guide the agent's behavior. It is important to experiment with different instructions to see what works best for your use case.
+
+You can also pass documents along with the user's message and ask questions about them.
+```python
# Initial document ingestion
response = agent.create_turn(
    messages=[

@@ -83,15 +83,15 @@ result = client.tool_runtime.invoke_tool(
)
```

-#### Memory
+#### RAG

-The Memory tool enables retrieval of context from various types of memory banks (vector, key-value, keyword, and graph).
+The RAG tool enables retrieval of context from various types of memory banks (vector, key-value, keyword, and graph).

```python
# Register Memory tool group
client.toolgroups.register(
-    toolgroup_id="builtin::memory",
-    provider_id="memory",
+    toolgroup_id="builtin::rag",
+    provider_id="faiss",
    args={"max_chunks": 5, "max_tokens_in_context": 4096},
)
```

@@ -102,7 +102,7 @@ Features:
 - Context retrieval with token limits

-> **Note:** By default, llama stack run.yaml defines toolgroups for web search, code interpreter and memory, that are provided by tavily-search, code-interpreter and memory providers.
+> **Note:** By default, llama stack run.yaml defines toolgroups for web search, code interpreter and rag, that are provided by tavily-search, code-interpreter and rag providers.

## Model Context Protocol (MCP) Tools

@@ -125,50 +125,43 @@ MCP tools require:
 - Tools are discovered dynamically from the endpoint

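Because the MCP registration details fall outside this hunk, here is a hedged sketch of what registering an MCP-backed toolgroup can look like. The toolgroup id, provider id, and endpoint shape are assumptions to adapt to your deployment; some client versions expect a dedicated URL type rather than a plain mapping for `mcp_endpoint`.

```python
# Sketch: register a toolgroup served by an external MCP server over SSE.
client.toolgroups.register(
    toolgroup_id="mcp::filesystem",            # assumed toolgroup id
    provider_id="model-context-protocol",
    mcp_endpoint={"uri": "http://localhost:8000/sse"},  # assumed endpoint shape
)
```
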
-## Tools provided by the client
+## Adding Custom Tools

-These tools are registered along with the agent config and are specific to the agent for which they are registered. The main difference between these tools and the tools provided by the built-in providers is that the execution of these tools is handled by the client and the agent transfers the tool call to the client and waits for the result from the client.
+When you want to use tools other than the built-in tools, you can implement a python function and decorate it with `@client_tool`.

+To define a custom tool, you need to use the `@client_tool` decorator.
+```python
+from llama_stack_client.lib.agents.client_tool import client_tool
+
+
+# Example tool definition
+@client_tool
+def my_tool(input: int) -> int:
+    """
+    Runs my awesome tool.
+
+    :param input: some int parameter
+    """
+    return input * 2
+```
+> **NOTE:** We employ python docstrings to describe the tool and the parameters. It is important to document the tool and the parameters so that the model can use the tool correctly. It is recommended to experiment with different docstrings to see how they affect the model's behavior.
+
+Once defined, simply pass the tool to the agent config. `Agent` will take care of the rest (calling the model with the tool definition, executing the tool, and returning the result to the model for the next iteration).
```python
# Example agent config with client provided tools
-config = AgentConfig(
-    toolgroups=[
-        "builtin::websearch",
-    ],
-    client_tools=[ToolDef(name="client_tool", description="Client provided tool")],
+client_tools = [
+    my_tool,
+]
+
+agent_config = AgentConfig(
+    ...,
+    client_tools=[client_tool.get_tool_definition() for client_tool in client_tools],
)
+agent = Agent(client, agent_config, client_tools)
```

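A hedged usage sketch for the configuration above: once the agent is created with `client_tools`, a normal turn lets the model decide when to call `my_tool`, and the client executes it locally (the prompt and session name here are illustrative, not part of the original example).

```python
# Sketch: exercise the client-side tool through a regular agent turn.
session_id = agent.create_session("client-tool-session")
response = agent.create_turn(
    messages=[{"role": "user", "content": "Use my_tool to double the number 7."}],
    session_id=session_id,
    stream=False,
)
print(response.output_message.content)
```
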
Refer to [llama-stack-apps](https://github.com/meta-llama/llama-stack-apps/blob/main/examples/agents/e2e_loop_with_client_tools.py) for an example of how to use client provided tools.

-## Tool Structure
-
-Each tool has the following components:
-
-- `name`: Unique identifier for the tool
-- `description`: Human-readable description of the tool's functionality
-- `parameters`: List of parameters the tool accepts
-  - `name`: Parameter name
-  - `parameter_type`: Data type (string, number, etc.)
-  - `description`: Parameter description
-  - `required`: Whether the parameter is required (default: true)
-  - `default`: Default value if any
-
-Example tool definition:
-```python
-{
-    "name": "web_search",
-    "description": "Search the web for information",
-    "parameters": [
-        {
-            "name": "query",
-            "parameter_type": "string",
-            "description": "The query to search for",
-            "required": True,
-        }
-    ],
-}
-```

## Tool Invocation

@@ -13,6 +13,17 @@
# https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information

from docutils import nodes
+import tomli  # Import tomli for TOML parsing
+from pathlib import Path
+
+# Read version from pyproject.toml
+with Path(__file__).parent.parent.parent.joinpath("pyproject.toml").open("rb") as f:
+    pyproject = tomli.load(f)
+llama_stack_version = pyproject["project"]["version"]
+
+# generate the full link including text and url here
+llama_stack_version_url = f"https://github.com/meta-llama/llama-stack/releases/tag/v{llama_stack_version}"
+llama_stack_version_link = f"<a href='{llama_stack_version_url}'>release notes</a>"

project = "llama-stack"
copyright = "2025, Meta"

@@ -66,6 +77,8 @@ myst_enable_extensions = [

myst_substitutions = {
    "docker_hub": "https://hub.docker.com/repository/docker/llamastack",
+    "llama_stack_version": llama_stack_version,
+    "llama_stack_version_link": llama_stack_version_link,
}

suppress_warnings = ['myst.header']

@@ -21,7 +21,7 @@ Here are some example PRs to help you get started:
 - Create integration tests that use real provider instances and configurations
 - For remote services, test actual API interactions
 - Avoid mocking at the provider level since adapter layers tend to be thin
-- Reference examples in {repopath}`tests/client-sdk`
+- Reference examples in {repopath}`tests/api`

### 2. Unit Testing (Optional)
 - Add unit tests for provider-specific functionality

@@ -30,7 +30,7 @@ Build a Llama stack container

options:
  -h, --help            show this help message and exit
-  --config CONFIG       Path to a config file to use for the build. You can find example configs in llama_stack/distribution/**/build.yaml.
+  --config CONFIG       Path to a config file to use for the build. You can find example configs in llama_stack/distributions/**/build.yaml.
                        If this argument is not provided, you will be prompted to enter information interactively
  --template TEMPLATE   Name of the example template config to use for build. You may use `llama stack build --list-templates` to check out the available templates
  --list-templates      Show the available templates for building a Llama Stack distribution

@@ -106,7 +106,7 @@ It would be best to start with a template and understand the structure of the co
llama stack build

> Enter a name for your Llama Stack (e.g. my-local-stack): my-stack
-> Enter the image type you want your Llama Stack to be built as (container or conda): conda
+> Enter the image type you want your Llama Stack to be built as (container or conda or venv): conda

Llama Stack is composed of several APIs working together. Let's select
the provider types (implementations) you want to use for these APIs.

@@ -187,14 +187,14 @@ usage: llama stack run [-h] [--port PORT] [--image-name IMAGE_NAME] [--disable-i
                       [--tls-certfile TLS_CERTFILE] [--image-type {conda,container,venv}]
                       config

-start the server for a Llama Stack Distribution. You should have already built (or downloaded) and configured the distribution.
+Start the server for a Llama Stack Distribution. You should have already built (or downloaded) and configured the distribution.

positional arguments:
  config                Path to config file to use for the run

options:
  -h, --help            show this help message and exit
-  --port PORT           Port to run the server on. Defaults to 8321
+  --port PORT           Port to run the server on. It can also be passed via the env var LLAMA_STACK_PORT. Defaults to 8321
  --image-name IMAGE_NAME
                        Name of the image to run. Defaults to the current conda environment
  --disable-ipv6        Disable IPv6 support

@@ -27,16 +27,19 @@ The following environment variables can be configured:

The following models are available by default:

-- `meta-llama/Llama-3-8B-Instruct (meta/llama3-8b-instruct)`
-- `meta-llama/Llama-3-70B-Instruct (meta/llama3-70b-instruct)`
-- `meta-llama/Llama-3.1-8B-Instruct (meta/llama-3.1-8b-instruct)`
-- `meta-llama/Llama-3.1-70B-Instruct (meta/llama-3.1-70b-instruct)`
-- `meta-llama/Llama-3.1-405B-Instruct-FP8 (meta/llama-3.1-405b-instruct)`
-- `meta-llama/Llama-3.2-1B-Instruct (meta/llama-3.2-1b-instruct)`
-- `meta-llama/Llama-3.2-3B-Instruct (meta/llama-3.2-3b-instruct)`
-- `meta-llama/Llama-3.2-11B-Vision-Instruct (meta/llama-3.2-11b-vision-instruct)`
-- `meta-llama/Llama-3.2-90B-Vision-Instruct (meta/llama-3.2-90b-vision-instruct)`
-- `baai/bge-m3 (baai/bge-m3)`
+- `meta/llama3-8b-instruct (aliases: meta-llama/Llama-3-8B-Instruct)`
+- `meta/llama3-70b-instruct (aliases: meta-llama/Llama-3-70B-Instruct)`
+- `meta/llama-3.1-8b-instruct (aliases: meta-llama/Llama-3.1-8B-Instruct)`
+- `meta/llama-3.1-70b-instruct (aliases: meta-llama/Llama-3.1-70B-Instruct)`
+- `meta/llama-3.1-405b-instruct (aliases: meta-llama/Llama-3.1-405B-Instruct-FP8)`
+- `meta/llama-3.2-1b-instruct (aliases: meta-llama/Llama-3.2-1B-Instruct)`
+- `meta/llama-3.2-3b-instruct (aliases: meta-llama/Llama-3.2-3B-Instruct)`
+- `meta/llama-3.2-11b-vision-instruct (aliases: meta-llama/Llama-3.2-11B-Vision-Instruct)`
+- `meta/llama-3.2-90b-vision-instruct (aliases: meta-llama/Llama-3.2-90B-Vision-Instruct)`
+- `nvidia/llama-3.2-nv-embedqa-1b-v2 `
+- `nvidia/nv-embedqa-e5-v5 `
+- `nvidia/nv-embedqa-mistral-7b-v2 `
+- `snowflake/arctic-embed-l `

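The last four entries are embedding models. As a hedged sketch of exercising one of them, assuming a `client` already connected to this distribution (argument and attribute names follow the client SDK and may differ slightly by version):

```python
# Sketch: request embeddings from one of the NVIDIA-hosted embedding models.
embeddings = client.inference.embeddings(
    model_id="nvidia/llama-3.2-nv-embedqa-1b-v2",
    contents=["Llama Stack unifies inference, safety, and memory APIs."],
)
print(len(embeddings.embeddings[0]))  # embedding dimension
```
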

### Prerequisite: API Keys

@@ -34,9 +34,9 @@ The following environment variables can be configured:

The following models are available by default:

-- `meta-llama/Llama-3.1-8B-Instruct (meta.llama3-1-8b-instruct-v1:0)`
-- `meta-llama/Llama-3.1-70B-Instruct (meta.llama3-1-70b-instruct-v1:0)`
-- `meta-llama/Llama-3.1-405B-Instruct-FP8 (meta.llama3-1-405b-instruct-v1:0)`
+- `meta.llama3-1-8b-instruct-v1:0 (aliases: meta-llama/Llama-3.1-8B-Instruct)`
+- `meta.llama3-1-70b-instruct-v1:0 (aliases: meta-llama/Llama-3.1-70B-Instruct)`
+- `meta.llama3-1-405b-instruct-v1:0 (aliases: meta-llama/Llama-3.1-405B-Instruct-FP8)`

### Prerequisite: API Keys

@@ -27,8 +27,8 @@ The following environment variables can be configured:

The following models are available by default:

-- `meta-llama/Llama-3.1-8B-Instruct (llama3.1-8b)`
-- `meta-llama/Llama-3.3-70B-Instruct (llama-3.3-70b)`
+- `llama3.1-8b (aliases: meta-llama/Llama-3.1-8B-Instruct)`
+- `llama-3.3-70b (aliases: meta-llama/Llama-3.3-70B-Instruct)`

### Prerequisite: API Keys

@@ -37,17 +37,17 @@ The following environment variables can be configured:

The following models are available by default:

-- `meta-llama/Llama-3.1-8B-Instruct (accounts/fireworks/models/llama-v3p1-8b-instruct)`
-- `meta-llama/Llama-3.1-70B-Instruct (accounts/fireworks/models/llama-v3p1-70b-instruct)`
-- `meta-llama/Llama-3.1-405B-Instruct-FP8 (accounts/fireworks/models/llama-v3p1-405b-instruct)`
-- `meta-llama/Llama-3.2-1B-Instruct (accounts/fireworks/models/llama-v3p2-1b-instruct)`
-- `meta-llama/Llama-3.2-3B-Instruct (accounts/fireworks/models/llama-v3p2-3b-instruct)`
-- `meta-llama/Llama-3.2-11B-Vision-Instruct (accounts/fireworks/models/llama-v3p2-11b-vision-instruct)`
-- `meta-llama/Llama-3.2-90B-Vision-Instruct (accounts/fireworks/models/llama-v3p2-90b-vision-instruct)`
-- `meta-llama/Llama-3.3-70B-Instruct (accounts/fireworks/models/llama-v3p3-70b-instruct)`
-- `meta-llama/Llama-Guard-3-8B (accounts/fireworks/models/llama-guard-3-8b)`
-- `meta-llama/Llama-Guard-3-11B-Vision (accounts/fireworks/models/llama-guard-3-11b-vision)`
-- `nomic-ai/nomic-embed-text-v1.5 (nomic-ai/nomic-embed-text-v1.5)`
+- `accounts/fireworks/models/llama-v3p1-8b-instruct (aliases: meta-llama/Llama-3.1-8B-Instruct)`
+- `accounts/fireworks/models/llama-v3p1-70b-instruct (aliases: meta-llama/Llama-3.1-70B-Instruct)`
+- `accounts/fireworks/models/llama-v3p1-405b-instruct (aliases: meta-llama/Llama-3.1-405B-Instruct-FP8)`
+- `accounts/fireworks/models/llama-v3p2-1b-instruct (aliases: meta-llama/Llama-3.2-1B-Instruct)`
+- `accounts/fireworks/models/llama-v3p2-3b-instruct (aliases: meta-llama/Llama-3.2-3B-Instruct)`
+- `accounts/fireworks/models/llama-v3p2-11b-vision-instruct (aliases: meta-llama/Llama-3.2-11B-Vision-Instruct)`
+- `accounts/fireworks/models/llama-v3p2-90b-vision-instruct (aliases: meta-llama/Llama-3.2-90B-Vision-Instruct)`
+- `accounts/fireworks/models/llama-v3p3-70b-instruct (aliases: meta-llama/Llama-3.3-70B-Instruct)`
+- `accounts/fireworks/models/llama-guard-3-8b (aliases: meta-llama/Llama-Guard-3-8B)`
+- `accounts/fireworks/models/llama-guard-3-11b-vision (aliases: meta-llama/Llama-Guard-3-11B-Vision)`
+- `nomic-ai/nomic-embed-text-v1.5 `

### Prerequisite: API Keys

docs/source/distributions/self_hosted_distro/groq.md (new file, 77 lines)
@@ -0,0 +1,77 @@
---
orphan: true
---
<!-- This file was auto-generated by distro_codegen.py, please edit source -->
# Groq Distribution

```{toctree}
:maxdepth: 2
:hidden:

self
```

The `llamastack/distribution-groq` distribution consists of the following provider configurations.

| API | Provider(s) |
|-----|-------------|
| agents | `inline::meta-reference` |
| datasetio | `remote::huggingface`, `inline::localfs` |
| eval | `inline::meta-reference` |
| inference | `remote::groq` |
| safety | `inline::llama-guard` |
| scoring | `inline::basic`, `inline::llm-as-judge`, `inline::braintrust` |
| telemetry | `inline::meta-reference` |
| tool_runtime | `remote::brave-search`, `remote::tavily-search`, `inline::code-interpreter`, `inline::rag-runtime` |
| vector_io | `inline::faiss` |

### Environment Variables

The following environment variables can be configured:

- `LLAMASTACK_PORT`: Port for the Llama Stack distribution server (default: `5001`)
- `GROQ_API_KEY`: Groq API Key (default: ``)

### Models

The following models are available by default:

- `groq/llama3-8b-8192 (aliases: meta-llama/Llama-3.1-8B-Instruct)`
- `groq/llama-3.1-8b-instant `
- `groq/llama3-70b-8192 (aliases: meta-llama/Llama-3-70B-Instruct)`
- `groq/llama-3.3-70b-versatile (aliases: meta-llama/Llama-3.3-70B-Instruct)`
- `groq/llama-3.2-3b-preview (aliases: meta-llama/Llama-3.2-3B-Instruct)`

### Prerequisite: API Keys

Make sure you have access to a Groq API Key. You can get one by visiting [Groq](https://api.groq.com/).

## Running Llama Stack with Groq

You can do this via Conda (build code) or Docker which has a pre-built image.

### Via Docker

This method allows you to get started quickly without having to build the distribution code.

```bash
LLAMA_STACK_PORT=5001
docker run \
  -it \
  -p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
  llamastack/distribution-groq \
  --port $LLAMA_STACK_PORT \
  --env GROQ_API_KEY=$GROQ_API_KEY
```

### Via Conda

```bash
llama stack build --template groq --image-type conda
llama stack run ./run.yaml \
  --port $LLAMA_STACK_PORT \
  --env GROQ_API_KEY=$GROQ_API_KEY
```

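Once the distribution is running via Docker or Conda as above, you can point the Python client at it. A hedged sketch; the model id is taken from the table above, but confirm it with `llama-stack-client models list` against your server, and adjust the port if you changed `LLAMA_STACK_PORT`.

```python
from llama_stack_client import LlamaStackClient

# Connect to the locally running Groq distribution.
client = LlamaStackClient(base_url="http://localhost:5001")

response = client.inference.chat_completion(
    model_id="groq/llama-3.3-70b-versatile",
    messages=[
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Write a haiku about fast inference."},
    ],
)
print(response.completion_message.content)
```
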
@@ -41,12 +41,31 @@ The following environment variables can be configured:

## Prerequisite: Downloading Models

-Please make sure you have llama model checkpoints downloaded in `~/.llama` before proceeding. See [installation guide](https://llama-stack.readthedocs.io/en/latest/references/llama_cli_reference/download_models.html) here to download the models. Run `llama model list` to see the available models to download, and `llama model download` to download the checkpoints.
+Please use `llama model list --downloaded` to check that you have llama model checkpoints downloaded in `~/.llama` before proceeding. See [installation guide](https://llama-stack.readthedocs.io/en/latest/references/llama_cli_reference/download_models.html) here to download the models. Run `llama model list` to see the available models to download, and `llama model download` to download the checkpoints.

```
-$ ls ~/.llama/checkpoints
-Llama3.1-8B           Llama3.2-11B-Vision-Instruct  Llama3.2-1B-Instruct  Llama3.2-90B-Vision-Instruct  Llama-Guard-3-8B
-Llama3.1-8B-Instruct  Llama3.2-1B                   Llama3.2-3B-Instruct  Llama-Guard-3-1B              Prompt-Guard-86M
+$ llama model list --downloaded
+┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━┓
+┃ Model                                   ┃ Size     ┃ Modified Time       ┃
+┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━┩
+│ Llama3.2-1B-Instruct:int4-qlora-eo8     │ 1.53 GB  │ 2025-02-26 11:22:28 │
+├─────────────────────────────────────────┼──────────┼─────────────────────┤
+│ Llama3.2-1B                             │ 2.31 GB  │ 2025-02-18 21:48:52 │
+├─────────────────────────────────────────┼──────────┼─────────────────────┤
+│ Prompt-Guard-86M                        │ 0.02 GB  │ 2025-02-26 11:29:28 │
+├─────────────────────────────────────────┼──────────┼─────────────────────┤
+│ Llama3.2-3B-Instruct:int4-spinquant-eo8 │ 3.69 GB  │ 2025-02-26 11:37:41 │
+├─────────────────────────────────────────┼──────────┼─────────────────────┤
+│ Llama3.2-3B                             │ 5.99 GB  │ 2025-02-18 21:51:26 │
+├─────────────────────────────────────────┼──────────┼─────────────────────┤
+│ Llama3.1-8B                             │ 14.97 GB │ 2025-02-16 10:36:37 │
+├─────────────────────────────────────────┼──────────┼─────────────────────┤
+│ Llama3.2-1B-Instruct:int4-spinquant-eo8 │ 1.51 GB  │ 2025-02-26 11:35:02 │
+├─────────────────────────────────────────┼──────────┼─────────────────────┤
+│ Llama-Guard-3-1B                        │ 2.80 GB  │ 2025-02-26 11:20:46 │
+├─────────────────────────────────────────┼──────────┼─────────────────────┤
+│ Llama-Guard-3-1B:int4                   │ 0.43 GB  │ 2025-02-26 11:33:33 │
+└─────────────────────────────────────────┴──────────┴─────────────────────┘
```

## Running the Distribution

@@ -41,12 +41,31 @@ The following environment variables can be configured:

## Prerequisite: Downloading Models

-Please make sure you have llama model checkpoints downloaded in `~/.llama` before proceeding. See [installation guide](https://llama-stack.readthedocs.io/en/latest/references/llama_cli_reference/download_models.html) here to download the models. Run `llama model list` to see the available models to download, and `llama model download` to download the checkpoints.
+Please use `llama model list --downloaded` to check that you have llama model checkpoints downloaded in `~/.llama` before proceeding. See [installation guide](https://llama-stack.readthedocs.io/en/latest/references/llama_cli_reference/download_models.html) here to download the models. Run `llama model list` to see the available models to download, and `llama model download` to download the checkpoints.

```
-$ ls ~/.llama/checkpoints
-Llama3.1-8B           Llama3.2-11B-Vision-Instruct  Llama3.2-1B-Instruct  Llama3.2-90B-Vision-Instruct  Llama-Guard-3-8B
-Llama3.1-8B-Instruct  Llama3.2-1B                   Llama3.2-3B-Instruct  Llama-Guard-3-1B              Prompt-Guard-86M
+$ llama model list --downloaded
+(the output is the same Model / Size / Modified Time table shown in the previous hunk)
```

## Running the Distribution

@@ -22,7 +22,7 @@ The `llamastack/distribution-ollama` distribution consists of the following prov
| safety | `inline::llama-guard` |
| scoring | `inline::basic`, `inline::llm-as-judge`, `inline::braintrust` |
| telemetry | `inline::meta-reference` |
-| tool_runtime | `remote::brave-search`, `remote::tavily-search`, `inline::code-interpreter`, `inline::rag-runtime` |
+| tool_runtime | `remote::brave-search`, `remote::tavily-search`, `inline::code-interpreter`, `inline::rag-runtime`, `remote::model-context-protocol` |
| vector_io | `inline::sqlite-vec`, `remote::chromadb`, `remote::pgvector` |

@@ -141,17 +141,21 @@ ollama run <model_name>
To make sure that the model is being served correctly, run `ollama ps` to get a list of models being served by ollama.
```
$ ollama ps
NAME                         ID              SIZE      PROCESSOR    UNTIL
-llama3.1:8b-instruct-fp16    4aacac419454    17 GB     100% GPU     4 minutes from now
+llama3.2:3b-instruct-fp16    195a8c01d91e    8.6 GB    100% GPU     9 minutes from now
```

To verify that the model served by ollama is correctly connected to Llama Stack server
```bash
$ llama-stack-client models list
-+----------------------+----------------------+---------------+-----------------------------------------------+
-| identifier           | llama_model          | provider_id   | metadata                                      |
-+======================+======================+===============+===============================================+
-| Llama3.1-8B-Instruct | Llama3.1-8B-Instruct | ollama0       | {'ollama_model': 'llama3.1:8b-instruct-fp16'} |
-+----------------------+----------------------+---------------+-----------------------------------------------+
+
+Available Models
+
+┏━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━┳━━━━━━━━━━━━━┓
+┃ model_type   ┃ identifier                           ┃ provider_resource_id         ┃ metadata  ┃ provider_id ┃
+┡━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━╇━━━━━━━━━━━━━┩
+│ llm          │ meta-llama/Llama-3.2-3B-Instruct     │ llama3.2:3b-instruct-fp16    │           │ ollama      │
+└──────────────┴──────────────────────────────────────┴──────────────────────────────┴───────────┴─────────────┘
+
+Total models: 1
```

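You can run the same check from Python. A small sketch using the client SDK, assuming the server started above is listening on the default port:

```python
from llama_stack_client import LlamaStackClient

client = LlamaStackClient(base_url="http://localhost:8321")

# Programmatic equivalent of `llama-stack-client models list`.
for model in client.models.list():
    print(model.model_type, model.identifier, model.provider_resource_id)
```
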
@@ -34,15 +34,15 @@ The following environment variables can be configured:

The following models are available by default:

-- `meta-llama/Llama-3.1-8B-Instruct (Meta-Llama-3.1-8B-Instruct)`
-- `meta-llama/Llama-3.1-70B-Instruct (Meta-Llama-3.1-70B-Instruct)`
-- `meta-llama/Llama-3.1-405B-Instruct-FP8 (Meta-Llama-3.1-405B-Instruct)`
-- `meta-llama/Llama-3.2-1B-Instruct (Meta-Llama-3.2-1B-Instruct)`
-- `meta-llama/Llama-3.2-3B-Instruct (Meta-Llama-3.2-3B-Instruct)`
-- `meta-llama/Llama-3.3-70B-Instruct (Meta-Llama-3.3-70B-Instruct)`
-- `meta-llama/Llama-3.2-11B-Vision-Instruct (Llama-3.2-11B-Vision-Instruct)`
-- `meta-llama/Llama-3.2-90B-Vision-Instruct (Llama-3.2-90B-Vision-Instruct)`
-- `meta-llama/Llama-Guard-3-8B (Meta-Llama-Guard-3-8B)`
+- `Meta-Llama-3.1-8B-Instruct (aliases: meta-llama/Llama-3.1-8B-Instruct)`
+- `Meta-Llama-3.1-70B-Instruct (aliases: meta-llama/Llama-3.1-70B-Instruct)`
+- `Meta-Llama-3.1-405B-Instruct (aliases: meta-llama/Llama-3.1-405B-Instruct-FP8)`
+- `Meta-Llama-3.2-1B-Instruct (aliases: meta-llama/Llama-3.2-1B-Instruct)`
+- `Meta-Llama-3.2-3B-Instruct (aliases: meta-llama/Llama-3.2-3B-Instruct)`
+- `Meta-Llama-3.3-70B-Instruct (aliases: meta-llama/Llama-3.3-70B-Instruct)`
+- `Llama-3.2-11B-Vision-Instruct (aliases: meta-llama/Llama-3.2-11B-Vision-Instruct)`
+- `Llama-3.2-90B-Vision-Instruct (aliases: meta-llama/Llama-3.2-90B-Vision-Instruct)`
+- `Meta-Llama-Guard-3-8B (aliases: meta-llama/Llama-Guard-3-8B)`

### Prerequisite: API Keys

@@ -35,7 +35,7 @@ The following environment variables can be configured:

 - `LLAMA_STACK_PORT`: Port for the Llama Stack distribution server (default: `5001`)
 - `INFERENCE_MODEL`: Inference model loaded into the TGI server (default: `meta-llama/Llama-3.2-3B-Instruct`)
-- `TGI_URL`: URL of the TGI server with the main inference model (default: `http://127.0.0.1:8080}/v1`)
+- `TGI_URL`: URL of the TGI server with the main inference model (default: `http://127.0.0.1:8080/v1`)
 - `TGI_SAFETY_URL`: URL of the TGI server with the safety model (default: `http://127.0.0.1:8081/v1`)
 - `SAFETY_MODEL`: Name of the safety (Llama-Guard) model to use (default: `meta-llama/Llama-Guard-3-1B`)

@@ -37,17 +37,17 @@ The following environment variables can be configured:

The following models are available by default:

-- `meta-llama/Llama-3.1-8B-Instruct`
-- `meta-llama/Llama-3.1-70B-Instruct`
-- `meta-llama/Llama-3.1-405B-Instruct-FP8`
-- `meta-llama/Llama-3.2-3B-Instruct`
-- `meta-llama/Llama-3.2-11B-Vision-Instruct`
-- `meta-llama/Llama-3.2-90B-Vision-Instruct`
-- `meta-llama/Llama-3.3-70B-Instruct`
-- `meta-llama/Llama-Guard-3-8B`
-- `meta-llama/Llama-Guard-3-11B-Vision`
-- `togethercomputer/m2-bert-80M-8k-retrieval`
-- `togethercomputer/m2-bert-80M-32k-retrieval`
+- `meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo (aliases: meta-llama/Llama-3.1-8B-Instruct)`
+- `meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo (aliases: meta-llama/Llama-3.1-70B-Instruct)`
+- `meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo (aliases: meta-llama/Llama-3.1-405B-Instruct-FP8)`
+- `meta-llama/Llama-3.2-3B-Instruct-Turbo (aliases: meta-llama/Llama-3.2-3B-Instruct)`
+- `meta-llama/Llama-3.2-11B-Vision-Instruct-Turbo (aliases: meta-llama/Llama-3.2-11B-Vision-Instruct)`
+- `meta-llama/Llama-3.2-90B-Vision-Instruct-Turbo (aliases: meta-llama/Llama-3.2-90B-Vision-Instruct)`
+- `meta-llama/Llama-3.3-70B-Instruct-Turbo (aliases: meta-llama/Llama-3.3-70B-Instruct)`
+- `meta-llama/Meta-Llama-Guard-3-8B (aliases: meta-llama/Llama-Guard-3-8B)`
+- `meta-llama/Llama-Guard-3-11B-Vision-Turbo (aliases: meta-llama/Llama-Guard-3-11B-Vision)`
+- `togethercomputer/m2-bert-80M-8k-retrieval `
+- `togethercomputer/m2-bert-80M-32k-retrieval `

### Prerequisite: API Keys

@@ -38,7 +38,7 @@ The API is **exactly identical** for both clients.

:::{dropdown} Starting up the Llama Stack server
The Llama Stack server can be configured flexibly so you can mix-and-match various providers for its individual API components -- beyond Inference, these include Vector IO, Agents, Telemetry, Evals, Post Training, etc.

-To get started quickly, we provide various container images for the server component that work with different inference providers out of the box. For this guide, we will use `llamastack/distribution-ollama` as the container image.
+To get started quickly, we provide various container images for the server component that work with different inference providers out of the box. For this guide, we will use `llamastack/distribution-ollama` as the container image. If you'd like to build your own image or customize the configurations, please check out [this guide](../references/index.md).

Lets setup some environment variables that we will use in the rest of the guide.
```bash

@@ -102,12 +102,18 @@ Let's use the `llama-stack-client` CLI to check the connectivity to the server.
$ llama-stack-client configure --endpoint http://localhost:$LLAMA_STACK_PORT
> Enter the API key (leave empty if no key is needed):
Done! You can now use the Llama Stack Client CLI with endpoint http://localhost:8321

$ llama-stack-client models list
-┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━┓
-┃ identifier                       ┃ provider_id ┃ provider_resource_id      ┃ metadata ┃
-┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━┩
-│ meta-llama/Llama-3.2-3B-Instruct │ ollama      │ llama3.2:3b-instruct-fp16 │          │
-└──────────────────────────────────┴─────────────┴───────────────────────────┴──────────┘
+
+Available Models
+
+┏━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━┳━━━━━━━━━━━━━┓
+┃ model_type   ┃ identifier                           ┃ provider_resource_id         ┃ metadata  ┃ provider_id ┃
+┡━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━╇━━━━━━━━━━━━━┩
+│ llm          │ meta-llama/Llama-3.2-3B-Instruct     │ llama3.2:3b-instruct-fp16    │           │ ollama      │
+└──────────────┴──────────────────────────────────────┴──────────────────────────────┴───────────┴─────────────┘
+
+Total models: 1
```

You can test basic Llama inference completion using the CLI too.

@@ -243,7 +249,7 @@ agent_config = AgentConfig(
    # Define tools available to the agent
    toolgroups=[
        {
-            "name": "builtin::rag",
+            "name": "builtin::rag/knowledge_search",
            "args": {
                "vector_db_ids": [vector_db_id],
            },

@@ -255,7 +261,7 @@ rag_agent = Agent(client, agent_config)
session_id = rag_agent.create_session("test-session")

user_prompts = [
-    "What are the top 5 topics that were explained? Only list succinct bullet points.",
+    "How to optimize memory usage in torchtune? use the knowledge_search tool to get information.",
]

# Run the agent loop by calling the `create_turn` method

@@ -1,8 +1,7 @@

```{admonition} News
:class: tip

-Llama Stack 0.1.3 is now available! See the [release notes](https://github.com/meta-llama/llama-stack/releases/tag/v0.1.3) for more details.
+Llama Stack {{ llama_stack_version }} is now available! See the {{ llama_stack_version_link }} for more details.
```

# Llama Stack

@@ -87,7 +87,7 @@ response = client.eval.evaluate_rows(
    benchmark_id="meta-reference::mmmu",
    input_rows=eval_rows,
    scoring_functions=["basic::regex_parser_multiple_choice_answer"],
-    task_config={
+    benchmark_config={
        "type": "benchmark",
        "eval_candidate": {
            "type": "model",

@@ -145,7 +145,7 @@ response = client.eval.evaluate_rows(
    benchmark_id="meta-reference::simpleqa",
    input_rows=eval_rows.rows,
    scoring_functions=["llm-as-judge::405b-simpleqa"],
-    task_config={
+    benchmark_config={
        "type": "benchmark",
        "eval_candidate": {
            "type": "model",

@@ -195,7 +195,7 @@ response = client.eval.evaluate_rows(
    benchmark_id="meta-reference::simpleqa",
    input_rows=eval_rows.rows,
    scoring_functions=["llm-as-judge::405b-simpleqa"],
-    task_config={
+    benchmark_config={
        "type": "benchmark",
        "eval_candidate": {
            "type": "agent",

@@ -129,3 +129,35 @@ llama download --source huggingface --model-id Prompt-Guard-86M --ignore-pattern
**Important:** Set your environment variable `HF_TOKEN` or pass in `--hf-token` to the command to validate your access. You can find your token at [https://huggingface.co/settings/tokens](https://huggingface.co/settings/tokens).

> **Tip:** Default for `llama download` is to run with `--ignore-patterns *.safetensors` since we use the `.pth` files in the `original` folder. For Llama Guard and Prompt Guard, however, we need safetensors. Hence, please run with `--ignore-patterns original` so that safetensors are downloaded and `.pth` files are ignored.

+## List the downloaded models
+
+To list the downloaded models with the following command:
+```
+llama model list --downloaded
+```
+
+You should see a table like this:
+```
+(the same Model / Size / Modified Time table shown earlier in this diff)
+```

@@ -154,6 +154,38 @@ llama download --source huggingface --model-id Prompt-Guard-86M --ignore-pattern

 > **Tip:** Default for `llama download` is to run with `--ignore-patterns *.safetensors` since we use the `.pth` files in the `original` folder. For Llama Guard and Prompt Guard, however, we need safetensors. Hence, please run with `--ignore-patterns original` so that safetensors are downloaded and `.pth` files are ignored.
+
+## List the downloaded models
+
+To list the downloaded models, run the following command:
+```
+llama model list --downloaded
+```
+
+You should see a table like this:
+```
+┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━┓
+┃ Model                                   ┃ Size     ┃ Modified Time       ┃
+┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━┩
+│ Llama3.2-1B-Instruct:int4-qlora-eo8     │ 1.53 GB  │ 2025-02-26 11:22:28 │
+├─────────────────────────────────────────┼──────────┼─────────────────────┤
+│ Llama3.2-1B                             │ 2.31 GB  │ 2025-02-18 21:48:52 │
+├─────────────────────────────────────────┼──────────┼─────────────────────┤
+│ Prompt-Guard-86M                        │ 0.02 GB  │ 2025-02-26 11:29:28 │
+├─────────────────────────────────────────┼──────────┼─────────────────────┤
+│ Llama3.2-3B-Instruct:int4-spinquant-eo8 │ 3.69 GB  │ 2025-02-26 11:37:41 │
+├─────────────────────────────────────────┼──────────┼─────────────────────┤
+│ Llama3.2-3B                             │ 5.99 GB  │ 2025-02-18 21:51:26 │
+├─────────────────────────────────────────┼──────────┼─────────────────────┤
+│ Llama3.1-8B                             │ 14.97 GB │ 2025-02-16 10:36:37 │
+├─────────────────────────────────────────┼──────────┼─────────────────────┤
+│ Llama3.2-1B-Instruct:int4-spinquant-eo8 │ 1.51 GB  │ 2025-02-26 11:35:02 │
+├─────────────────────────────────────────┼──────────┼─────────────────────┤
+│ Llama-Guard-3-1B                        │ 2.80 GB  │ 2025-02-26 11:20:46 │
+├─────────────────────────────────────────┼──────────┼─────────────────────┤
+│ Llama-Guard-3-1B:int4                   │ 0.43 GB  │ 2025-02-26 11:33:33 │
+└─────────────────────────────────────────┴──────────┴─────────────────────┘
+```

 ## Understand the models
 The `llama model` command helps you explore the model’s interface.
@@ -6,7 +6,7 @@ The `llama-stack-client` CLI allows you to query information about the distribut

 ### `llama-stack-client`
 ```bash
-$ llama-stack-client -h
+llama-stack-client -h

 usage: llama-stack-client [-h] {models,memory_banks,shields} ...

@@ -21,7 +21,7 @@ subcommands:

 ### `llama-stack-client configure`
 ```bash
-$ llama-stack-client configure
+llama-stack-client configure
 > Enter the host name of the Llama Stack distribution server: localhost
 > Enter the port number of the Llama Stack distribution server: 8321
 Done! You can now use the Llama Stack Client CLI with endpoint http://localhost:8321

@@ -29,7 +29,7 @@ Done! You can now use the Llama Stack Client CLI with endpoint http://localhost:

 ### `llama-stack-client providers list`
 ```bash
-$ llama-stack-client providers list
+llama-stack-client providers list
 ```
 ```
 +-----------+----------------+-----------------+

@@ -55,19 +55,23 @@ $ llama-stack-client providers list

 ### `llama-stack-client models list`
 ```bash
-$ llama-stack-client models list
+llama-stack-client models list
 ```
 ```
-+----------------------+----------------------+---------------+----------------------------------------------------------+
-| identifier           | llama_model          | provider_id   | metadata                                                 |
-+======================+======================+===============+==========================================================+
-| Llama3.1-8B-Instruct | Llama3.1-8B-Instruct | tgi0          | {'huggingface_repo': 'meta-llama/Llama-3.1-8B-Instruct'} |
-+----------------------+----------------------+---------------+----------------------------------------------------------+
+Available Models
+
+┏━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━┳━━━━━━━━━━━━━┓
+┃ model_type   ┃ identifier                           ┃ provider_resource_id         ┃ metadata  ┃ provider_id ┃
+┡━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━╇━━━━━━━━━━━━━┩
+│ llm          │ meta-llama/Llama-3.2-3B-Instruct     │ llama3.2:3b-instruct-fp16    │           │ ollama      │
+└──────────────┴──────────────────────────────────────┴──────────────────────────────┴───────────┴─────────────┘
+
+Total models: 1
 ```

 ### `llama-stack-client models get`
 ```bash
-$ llama-stack-client models get Llama3.1-8B-Instruct
+llama-stack-client models get Llama3.1-8B-Instruct
 ```

 ```

@@ -80,7 +84,7 @@ $ llama-stack-client models get Llama3.1-8B-Instruct

 ```bash
-$ llama-stack-client models get Random-Model
+llama-stack-client models get Random-Model

 Model RandomModel is not found at distribution endpoint host:port. Please ensure endpoint is serving specified model.
 ```

@@ -88,26 +92,26 @@ Model RandomModel is not found at distribution endpoint host:port. Please ensure
 ### `llama-stack-client models register`

 ```bash
-$ llama-stack-client models register <model_id> [--provider-id <provider_id>] [--provider-model-id <provider_model_id>] [--metadata <metadata>]
+llama-stack-client models register <model_id> [--provider-id <provider_id>] [--provider-model-id <provider_model_id>] [--metadata <metadata>]
 ```

 ### `llama-stack-client models update`

 ```bash
-$ llama-stack-client models update <model_id> [--provider-id <provider_id>] [--provider-model-id <provider_model_id>] [--metadata <metadata>]
+llama-stack-client models update <model_id> [--provider-id <provider_id>] [--provider-model-id <provider_model_id>] [--metadata <metadata>]
 ```

 ### `llama-stack-client models delete`

 ```bash
-$ llama-stack-client models delete <model_id>
+llama-stack-client models delete <model_id>
 ```

 ## Vector DB Management

 ### `llama-stack-client vector_dbs list`
 ```bash
-$ llama-stack-client vector_dbs list
+llama-stack-client vector_dbs list
 ```
 ```
 +--------------+----------------+---------------------+---------------+------------------------+

@@ -120,7 +124,7 @@ $ llama-stack-client vector_dbs list

 ### `llama-stack-client vector_dbs register`
 ```bash
-$ llama-stack-client vector_dbs register <vector-db-id> [--provider-id <provider-id>] [--provider-vector-db-id <provider-vector-db-id>] [--embedding-model <embedding-model>] [--embedding-dimension <embedding-dimension>]
+llama-stack-client vector_dbs register <vector-db-id> [--provider-id <provider-id>] [--provider-vector-db-id <provider-vector-db-id>] [--embedding-model <embedding-model>] [--embedding-dimension <embedding-dimension>]
 ```

 Options:

@@ -131,13 +135,13 @@ Options:

 ### `llama-stack-client vector_dbs unregister`
 ```bash
-$ llama-stack-client vector_dbs unregister <vector-db-id>
+llama-stack-client vector_dbs unregister <vector-db-id>
 ```

 ## Shield Management
 ### `llama-stack-client shields list`
 ```bash
-$ llama-stack-client shields list
+llama-stack-client shields list
 ```

 ```

@@ -150,7 +154,7 @@ $ llama-stack-client shields list

 ### `llama-stack-client shields register`
 ```bash
-$ llama-stack-client shields register --shield-id <shield-id> [--provider-id <provider-id>] [--provider-shield-id <provider-shield-id>] [--params <params>]
+llama-stack-client shields register --shield-id <shield-id> [--provider-id <provider-id>] [--provider-shield-id <provider-shield-id>] [--params <params>]
 ```

 Options:

@@ -163,12 +167,12 @@ Options:

 ### `llama-stack-client benchmarks list`
 ```bash
-$ llama-stack-client benchmarks list
+llama-stack-client benchmarks list
 ```

 ### `llama-stack-client benchmarks register`
 ```bash
-$ llama-stack-client benchmarks register --eval-task-id <eval-task-id> --dataset-id <dataset-id> --scoring-functions <function1> [<function2> ...] [--provider-id <provider-id>] [--provider-eval-task-id <provider-eval-task-id>] [--metadata <metadata>]
+llama-stack-client benchmarks register --eval-task-id <eval-task-id> --dataset-id <dataset-id> --scoring-functions <function1> [<function2> ...] [--provider-id <provider-id>] [--provider-eval-task-id <provider-eval-task-id>] [--metadata <metadata>]
 ```

 Options:

@@ -182,7 +186,7 @@ Options:
 ## Eval execution
 ### `llama-stack-client eval run-benchmark`
 ```bash
-$ llama-stack-client eval run-benchmark <eval-task-id1> [<eval-task-id2> ...] --eval-task-config <config-file> --output-dir <output-dir> [--num-examples <num>] [--visualize]
+llama-stack-client eval run-benchmark <eval-task-id1> [<eval-task-id2> ...] --eval-task-config <config-file> --output-dir <output-dir> [--num-examples <num>] [--visualize]
 ```

 Options:

@@ -207,7 +211,7 @@ Example benchmark_config.json:

 ### `llama-stack-client eval run-scoring`
 ```bash
-$ llama-stack-client eval run-scoring <eval-task-id> --eval-task-config <config-file> --output-dir <output-dir> [--num-examples <num>] [--visualize]
+llama-stack-client eval run-scoring <eval-task-id> --eval-task-config <config-file> --output-dir <output-dir> [--num-examples <num>] [--visualize]
 ```

 Options:

@@ -220,7 +224,7 @@ Options:

 ### `llama-stack-client toolgroups list`
 ```bash
-$ llama-stack-client toolgroups list
+llama-stack-client toolgroups list
 ```
 ```
 +---------------------------+------------------+------+---------------+

@@ -236,14 +240,14 @@ $ llama-stack-client toolgroups list

 ### `llama-stack-client toolgroups get`
 ```bash
-$ llama-stack-client toolgroups get <toolgroup_id>
+llama-stack-client toolgroups get <toolgroup_id>
 ```

 Shows detailed information about a specific toolgroup. If the toolgroup is not found, displays an error message.

 ### `llama-stack-client toolgroups register`
 ```bash
-$ llama-stack-client toolgroups register <toolgroup_id> [--provider-id <provider-id>] [--provider-toolgroup-id <provider-toolgroup-id>] [--mcp-config <mcp-config>] [--args <args>]
+llama-stack-client toolgroups register <toolgroup_id> [--provider-id <provider-id>] [--provider-toolgroup-id <provider-toolgroup-id>] [--mcp-config <mcp-config>] [--args <args>]
 ```

 Options:

@@ -254,5 +258,5 @@ Options:

 ### `llama-stack-client toolgroups unregister`
 ```bash
-$ llama-stack-client toolgroups unregister <toolgroup_id>
+llama-stack-client toolgroups unregister <toolgroup_id>
 ```
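The CLI reference above drops the leading `$ ` so the commands can be copy-pasted directly, and the `models list` output switches to the richer tabular format. The same information is available programmatically through the client SDK; a small sketch, assuming a distribution is already serving at `http://localhost:8321`:

```python
# Programmatic equivalent of `llama-stack-client models list` (sketch; assumes a
# running distribution at the given base URL).
from llama_stack_client import LlamaStackClient

client = LlamaStackClient(base_url="http://localhost:8321")

for model in client.models.list():
    print(model.model_type, model.identifier, model.provider_id)
```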
@@ -6,7 +6,7 @@ This guide will walk you through an end-to-end workflow with Llama Stack with Ol

 If you're looking for more specific topics, we have a [Zero to Hero Guide](#next-steps) that covers everything from Tool Calling to Agents in detail. Feel free to skip to the end to explore the advanced topics you're interested in.

-> If you'd prefer not to set up a local server, explore our notebook on [tool calling with the Together API](Tool_Calling101_Using_Together's_Llama_Stack_Server.ipynb). This notebook will show you how to leverage together.ai's Llama Stack Server API, allowing you to get started with Llama Stack without the need for a locally built and running server.
+> If you'd prefer not to set up a local server, explore our notebook on [tool calling with the Together API](Tool_Calling101_Using_Together_Llama_Stack_Server.ipynb). This notebook will show you how to leverage together.ai's Llama Stack Server API, allowing you to get started with Llama Stack without the need for a locally built and running server.

 ## Table of Contents
 1. [Setup and run ollama](#setup-ollama)

@@ -73,7 +73,7 @@ If you're looking for more specific topics, we have a [Zero to Hero Guide](#next
 Open a new terminal and install `llama-stack`:
 ```bash
 conda activate ollama
-pip install llama-stack==0.1.0
+pip install -U llama-stack
 ```

 ---
@@ -296,9 +296,6 @@ class AgentTurnCreateRequest(AgentConfigOverridablePerTurn):
     stream: Optional[bool] = False
     tool_config: Optional[ToolConfig] = None

-    # TODO (xiyan): temporary flag, will remove for 0.1.5
-    allow_turn_resume: Optional[bool] = False
-

 @json_schema_type
 class AgentTurnResumeRequest(BaseModel):

@@ -355,7 +352,6 @@ class Agents(Protocol):
         documents: Optional[List[Document]] = None,
         toolgroups: Optional[List[AgentToolGroup]] = None,
         tool_config: Optional[ToolConfig] = None,
-        allow_turn_resume: Optional[bool] = False,
     ) -> Union[Turn, AsyncIterator[AgentTurnResponseStreamChunk]]: ...

     @webmethod(
@@ -5,6 +5,9 @@
 # the root directory of this source tree.

 from enum import Enum
+from typing import Optional
+
+from pydantic import BaseModel

 from llama_stack.schema_utils import json_schema_type

@@ -33,3 +36,20 @@ class Api(Enum):

     # built-in API
     inspect = "inspect"
+
+
+@json_schema_type
+class Error(BaseModel):
+    """
+    Error response from the API. Roughly follows RFC 7807.
+
+    :param status: HTTP status code
+    :param title: Error title, a short summary of the error which is invariant for an error type
+    :param detail: Error detail, a longer human-readable description of the error
+    :param instance: (Optional) A URL which can be used to retrieve more information about the specific occurrence of the error
+    """
+
+    status: int
+    title: str
+    detail: str
+    instance: Optional[str] = None
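As a quick illustration of the new `Error` shape, here is a standalone sketch using only the fields from the hunk above; the class is redefined locally so the snippet runs on its own, and the example values plus the pydantic v2 `model_dump_json` call are assumptions, not repo code.

```python
# Standalone sketch of the RFC 7807-style Error payload added above (pydantic v2 assumed).
from typing import Optional

from pydantic import BaseModel


class Error(BaseModel):
    status: int
    title: str
    detail: str
    instance: Optional[str] = None


err = Error(
    status=404,
    title="Model not found",
    detail="Model 'Random-Model' is not served by this distribution.",
    instance="http://localhost:8321/v1/models/Random-Model",  # illustrative URL
)
print(err.model_dump_json())
```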
@@ -63,7 +63,7 @@ class Eval(Protocol):
     async def run_eval(
         self,
         benchmark_id: str,
-        task_config: BenchmarkConfig,
+        benchmark_config: BenchmarkConfig,
     ) -> Job: ...

     @webmethod(route="/eval/benchmarks/{benchmark_id}/evaluations", method="POST")

@@ -72,7 +72,7 @@ class Eval(Protocol):
         benchmark_id: str,
         input_rows: List[Dict[str, Any]],
         scoring_functions: List[str],
-        task_config: BenchmarkConfig,
+        benchmark_config: BenchmarkConfig,
     ) -> EvaluateResponse: ...

     @webmethod(route="/eval/benchmarks/{benchmark_id}/jobs/{job_id}", method="GET")
@@ -343,7 +343,7 @@ def _hf_download(
             "You can find your token by visiting https://huggingface.co/settings/tokens"
         )
     except RepositoryNotFoundError:
-        parser.error(f"Repository '{repo_id}' not found on the Hugging Face Hub.")
+        parser.error(f"Repository '{repo_id}' not found on the Hugging Face Hub or incorrect Hugging Face token.")
     except Exception as e:
         parser.error(e)
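The expanded message covers a common failure mode: with a missing or invalid token, the Hugging Face Hub can report a gated or private repository as not found even though it exists. A hedged standalone sketch of that behavior is below; the repo id and token handling are illustrative and not llama-stack code.

```python
# Illustrative only: an invalid/missing token can surface a gated repo as
# RepositoryNotFoundError, which is why the CLI message mentions the token.
import os

from huggingface_hub import snapshot_download
from huggingface_hub.utils import GatedRepoError, RepositoryNotFoundError

try:
    snapshot_download(
        repo_id="meta-llama/Llama-Guard-3-1B",
        token=os.environ.get("HF_TOKEN"),
    )
except GatedRepoError:
    print("Access is gated: request access and pass a valid HF_TOKEN.")
except RepositoryNotFoundError:
    print("Repository not found on the Hugging Face Hub or incorrect Hugging Face token.")
```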
@@ -9,6 +9,7 @@ import argparse
 from .download import Download
 from .model import ModelParser
 from .stack import StackParser
+from .stack.utils import print_subcommand_description
 from .verify_download import VerifyDownload


@@ -20,6 +21,7 @@ class LlamaCLIParser:
             prog="llama",
             description="Welcome to the Llama CLI",
             add_help=True,
+            formatter_class=argparse.RawTextHelpFormatter,
         )

         # Default command is to print help

@@ -33,6 +35,8 @@ class LlamaCLIParser:
         Download.create(subparsers)
         VerifyDownload.create(subparsers)

+        print_subcommand_description(self.parser, subparsers)
+
     def parse_args(self) -> argparse.Namespace:
         return self.parser.parse_args()
@@ -7,8 +7,6 @@
 import argparse
 import json

-from termcolor import colored
-
 from llama_stack.cli.subcommand import Subcommand
 from llama_stack.cli.table import print_table
 from llama_stack.models.llama.sku_list import resolve_model

@@ -52,11 +50,12 @@ class ModelDescribe(Subcommand):
             )
             return

+        headers = [
+            "Model",
+            model.descriptor(),
+        ]
+
         rows = [
-            (
-                colored("Model", "white", attrs=["bold"]),
-                colored(model.descriptor(), "white", attrs=["bold"]),
-            ),
             ("Hugging Face ID", model.huggingface_repo or "<Not Available>"),
             ("Description", model.description),
             ("Context Length", f"{model.max_seq_length // 1024}K tokens"),

@@ -77,5 +76,6 @@ class ModelDescribe(Subcommand):

         print_table(
             rows,
+            headers,
             separate_rows=True,
         )
@@ -12,6 +12,7 @@ from llama_stack.cli.model.list import ModelList
 from llama_stack.cli.model.prompt_format import ModelPromptFormat
 from llama_stack.cli.model.remove import ModelRemove
 from llama_stack.cli.model.verify_download import ModelVerifyDownload
+from llama_stack.cli.stack.utils import print_subcommand_description
 from llama_stack.cli.subcommand import Subcommand


@@ -24,6 +25,7 @@ class ModelParser(Subcommand):
             "model",
             prog="llama model",
             description="Work with llama models",
+            formatter_class=argparse.RawTextHelpFormatter,
         )

         self.parser.set_defaults(func=lambda args: self.parser.print_help())

@@ -37,3 +39,5 @@ class ModelParser(Subcommand):
         ModelDescribe.create(subparsers)
         ModelVerifyDownload.create(subparsers)
         ModelRemove.create(subparsers)
+
+        print_subcommand_description(self.parser, subparsers)
@@ -7,10 +7,14 @@
 import argparse
 import textwrap
 from io import StringIO
+from pathlib import Path

 from llama_stack.cli.subcommand import Subcommand
+from llama_stack.cli.table import print_table
 from llama_stack.models.llama.datatypes import CoreModelId, ModelFamily, is_multimodal, model_family

+ROOT_DIR = Path(__file__).parent.parent
+

 class ModelPromptFormat(Subcommand):
     """Llama model cli for describe a model prompt format (message formats)"""

@@ -48,7 +52,26 @@ class ModelPromptFormat(Subcommand):
         supported_model_ids = [
             m for m in CoreModelId if model_family(m) in {ModelFamily.llama3_1, ModelFamily.llama3_2}
         ]
-        model_str = "\n".join([m.value for m in supported_model_ids])
+
+        model_list = [m.value for m in supported_model_ids]
+        model_str = "\n".join(model_list)
+
+        if args.list:
+            headers = ["Model(s)"]
+            rows = []
+            for m in model_list:
+                rows.append(
+                    [
+                        m,
+                    ]
+                )
+            print_table(
+                rows,
+                headers,
+                separate_rows=True,
+            )
+            return

         try:
             model_id = CoreModelId(args.model_name)
         except ValueError:

@@ -57,9 +80,9 @@ class ModelPromptFormat(Subcommand):
         if model_id not in supported_model_ids:
             self.parser.error(f"{model_id} is not a valid Model. Choose one from --\n {model_str}")

-        llama_3_1_file = importlib.resources.files("llama_models") / "llama3_1/prompt_format.md"
-        llama_3_2_text_file = importlib.resources.files("llama_models") / "llama3_2/text_prompt_format.md"
-        llama_3_2_vision_file = importlib.resources.files("llama_models") / "llama3_2/vision_prompt_format.md"
+        llama_3_1_file = ROOT_DIR / "models" / "llama" / "llama3_1" / "prompt_format.md"
+        llama_3_2_text_file = ROOT_DIR / "models" / "llama" / "llama3_2" / "text_prompt_format.md"
+        llama_3_2_vision_file = ROOT_DIR / "models" / "llama" / "llama3_2" / "vision_prompt_format.md"
         if model_family(model_id) == ModelFamily.llama3_1:
             with importlib.resources.as_file(llama_3_1_file) as f:
                 content = f.open("r").read()
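The prompt-format markdown now ships inside `llama_stack` itself and is resolved relative to the CLI module instead of via `importlib.resources` on the external `llama_models` package. A minimal standalone sketch of that pattern (the directory layout here is illustrative, not the exact repo tree):

```python
# Minimal sketch: resolve packaged markdown relative to the current module.
from pathlib import Path

ROOT_DIR = Path(__file__).parent.parent  # e.g. the package root, assumption for illustration
prompt_format_md = ROOT_DIR / "models" / "llama" / "llama3_1" / "prompt_format.md"

if prompt_format_md.exists():
    print(prompt_format_md.read_text()[:200])
else:
    print(f"Not found: {prompt_format_md}")
```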
@@ -38,7 +38,7 @@ from llama_stack.distribution.distribution import get_provider_registry
 from llama_stack.distribution.resolver import InvalidProviderError
 from llama_stack.distribution.utils.config_dirs import DISTRIBS_BASE_DIR
 from llama_stack.distribution.utils.dynamic import instantiate_class_type
-from llama_stack.distribution.utils.exec import formulate_run_args, in_notebook, run_with_pty
+from llama_stack.distribution.utils.exec import formulate_run_args, run_with_pty
 from llama_stack.distribution.utils.image_types import ImageType
 from llama_stack.providers.datatypes import Api

@@ -65,8 +65,6 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
     if args.image_type == "venv":
         current_venv = os.environ.get("VIRTUAL_ENV")
         image_name = args.image_name or current_venv
-        if not image_name and in_notebook():
-            image_name = "__system__"
     elif args.image_type == "conda":
         current_conda_env = os.environ.get("CONDA_DEFAULT_ENV")
         image_name = args.image_name or current_conda_env

@@ -143,7 +141,7 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
                 completer=WordCompleter(available_providers),
                 complete_while_typing=True,
                 validator=Validator.from_callable(
-                    lambda x: x in available_providers,
+                    lambda x: x in available_providers,  # noqa: B023 - see https://github.com/astral-sh/ruff/issues/7847
                     error_message="Invalid provider, use <TAB> to see options",
                 ),
             )

@@ -250,7 +248,7 @@ def _generate_run_config(

         config_type = instantiate_class_type(provider_registry[Api(api)][provider_type].config_class)
         if hasattr(config_type, "sample_run_config"):
-            config = config_type.sample_run_config(__distro_dir__=f"distributions/{image_name}")
+            config = config_type.sample_run_config(__distro_dir__=f"~/.llama/distributions/{image_name}")
         else:
             config = {}

@@ -291,6 +289,8 @@ def _run_stack_build_command_from_build_config(
         if not image_name:
             raise ValueError("Please specify an image name when building a conda image")
     elif build_config.image_type == ImageType.venv.value:
+        if not image_name and os.environ.get("UV_SYSTEM_PYTHON"):
+            image_name = "__system__"
         if not image_name:
             raise ValueError("Please specify an image name when building a venv image")

@@ -26,7 +26,7 @@ class StackBuild(Subcommand):
             "--config",
             type=str,
             default=None,
-            help="Path to a config file to use for the build. You can find example configs in llama_stack/distribution/**/build.yaml. If this argument is not provided, you will be prompted to enter information interactively",
+            help="Path to a config file to use for the build. You can find example configs in llama_stack/distributions/**/build.yaml. If this argument is not provided, you will be prompted to enter information interactively",
         )

         self.parser.add_argument(
@@ -5,6 +5,7 @@
 # the root directory of this source tree.

 import argparse
+import logging
 import os
 from pathlib import Path

@@ -12,6 +13,8 @@ from llama_stack.cli.subcommand import Subcommand

 REPO_ROOT = Path(__file__).parent.parent.parent.parent

+logger = logging.getLogger(__name__)
+

 class StackRun(Subcommand):
     def __init__(self, subparsers: argparse._SubParsersAction):

@@ -34,7 +37,7 @@ class StackRun(Subcommand):
         self.parser.add_argument(
             "--port",
             type=int,
-            help="Port to run the server on. Defaults to 8321",
+            help="Port to run the server on. It can also be passed via the env var LLAMA_STACK_PORT. Defaults to 8321",
             default=int(os.getenv("LLAMA_STACK_PORT", 8321)),
         )
         self.parser.add_argument(

@@ -75,7 +78,6 @@ class StackRun(Subcommand):

     def _run_stack_run_cmd(self, args: argparse.Namespace) -> None:
         import yaml
-        from termcolor import cprint

         from llama_stack.distribution.build import ImageType
         from llama_stack.distribution.configure import parse_and_maybe_upgrade_config

@@ -85,10 +87,6 @@ class StackRun(Subcommand):
         )
         from llama_stack.distribution.utils.exec import formulate_run_args, run_with_pty

-        if not args.config:
-            self.parser.error("Must specify a config file to run")
-            return
-
         config_file = Path(args.config)
         has_yaml_suffix = args.config.endswith(".yaml")
         template_name = None

@@ -115,11 +113,23 @@ class StackRun(Subcommand):
             self.parser.error(
                 f"File {str(config_file)} does not exist.\n\nPlease run `llama stack build` to generate (and optionally edit) a run.yaml file"
             )
-            return

-        print(f"Using run configuration: {config_file}")
-        config_dict = yaml.safe_load(config_file.read_text())
-        config = parse_and_maybe_upgrade_config(config_dict)
+        if not config_file.is_file():
+            self.parser.error(
+                f"Config file must be a valid file path, '{config_file}’ is not a file: type={type(config_file)}"
+            )
+
+        logger.info(f"Using run configuration: {config_file}")
+
+        try:
+            config_dict = yaml.safe_load(config_file.read_text())
+        except yaml.parser.ParserError as e:
+            self.parser.error(f"failed to load config file '{config_file}':\n {e}")
+
+        try:
+            config = parse_and_maybe_upgrade_config(config_dict)
+        except AttributeError as e:
+            self.parser.error(f"failed to parse config file '{config_file}':\n {e}")

         run_args = formulate_run_args(args.image_type, args.image_name, config, template_name)

@@ -129,18 +139,10 @@ class StackRun(Subcommand):

             for env_var in args.env:
                 if "=" not in env_var:
-                    cprint(
-                        f"Environment variable '{env_var}' must be in KEY=VALUE format",
-                        color="red",
-                    )
-                    return
+                    self.parser.error(f"Environment variable '{env_var}' must be in KEY=VALUE format")
                 key, value = env_var.split("=", 1)  # split on first = only
                 if not key:
-                    cprint(
-                        f"Environment variable '{env_var}' has empty key",
-                        color="red",
-                    )
-                    return
+                    self.parser.error(f"Environment variable '{env_var}' has empty key")
                 run_args.extend(["--env", f"{key}={value}"])

         if args.tls_keyfile and args.tls_certfile:
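Replacing `cprint(...)` plus `return` with `self.parser.error(...)` changes the failure mode from a printed warning to argparse's standard behavior: the usage line and message go to stderr and the process exits with status 2. A tiny standalone sketch of that behavior (the parser and flag below are illustrative, not the real CLI wiring):

```python
# Standalone sketch: argparse's parser.error() prints usage plus the message to
# stderr and exits with status 2, so no explicit return is needed afterwards.
import argparse

parser = argparse.ArgumentParser(prog="llama stack run")
parser.add_argument("--env", action="append", default=[])
args = parser.parse_args(["--env", "FOO"])  # deliberately missing '=' to trigger the error

for env_var in args.env:
    if "=" not in env_var:
        parser.error(f"Environment variable '{env_var}' must be in KEY=VALUE format")
```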
@@ -7,6 +7,7 @@
 import argparse
 from importlib.metadata import version

+from llama_stack.cli.stack.utils import print_subcommand_description
 from llama_stack.cli.subcommand import Subcommand

 from .build import StackBuild

@@ -22,6 +23,7 @@ class StackParser(Subcommand):
             "stack",
             prog="llama stack",
             description="Operations for the Llama Stack / Distributions",
+            formatter_class=argparse.RawTextHelpFormatter,
         )

         self.parser.add_argument(

@@ -39,3 +41,5 @@ class StackParser(Subcommand):
         StackListApis.create(subparsers)
         StackListProviders.create(subparsers)
         StackRun.create(subparsers)
+
+        print_subcommand_description(self.parser, subparsers)
14 llama_stack/cli/stack/utils.py (new file)

@@ -0,0 +1,14 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+
+def print_subcommand_description(parser, subparsers):
+    """Print descriptions of subcommands."""
+    description_text = ""
+    for name, subcommand in subparsers.choices.items():
+        description = subcommand.description
+        description_text += f" {name:<21} {description}\n"
+    parser.epilog = description_text
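A small usage sketch of the new helper follows; the subcommands registered here are made up for illustration and are not the real `llama` CLI tree.

```python
# Usage sketch for print_subcommand_description; subcommand names are illustrative.
import argparse

from llama_stack.cli.stack.utils import print_subcommand_description

parser = argparse.ArgumentParser(prog="llama", formatter_class=argparse.RawTextHelpFormatter)
subparsers = parser.add_subparsers(title="subcommands")
subparsers.add_parser("model", description="Work with llama models")
subparsers.add_parser("stack", description="Operations for the Llama Stack / Distributions")

print_subcommand_description(parser, subparsers)
parser.print_help()  # the epilog now lists each subcommand with its description
```

Using `RawTextHelpFormatter` in the parsers (as the other hunks in this commit do) keeps the per-line layout of the generated epilog intact.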
@@ -15,7 +15,6 @@ from termcolor import cprint

 from llama_stack.distribution.datatypes import BuildConfig, Provider
 from llama_stack.distribution.distribution import get_provider_registry
-from llama_stack.distribution.utils.config_dirs import BUILDS_BASE_DIR
 from llama_stack.distribution.utils.exec import run_command, run_with_pty
 from llama_stack.distribution.utils.image_types import ImageType
 from llama_stack.providers.datatypes import Api

@@ -103,8 +102,6 @@ def build_image(
             template_or_config,
             image_name,
             container_base,
-            str(build_file_path),
-            str(BUILDS_BASE_DIR / ImageType.container.value),
             " ".join(normal_deps),
         ]
     elif build_config.image_type == ImageType.conda.value:
@@ -6,8 +6,8 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-LLAMA_MODELS_DIR=${LLAMA_MODELS_DIR:-}
 LLAMA_STACK_DIR=${LLAMA_STACK_DIR:-}
+LLAMA_STACK_CLIENT_DIR=${LLAMA_STACK_CLIENT_DIR:-}
 TEST_PYPI_VERSION=${TEST_PYPI_VERSION:-}
 # This timeout (in seconds) is necessary when installing PyTorch via uv since it's likely to time out
 # Reference: https://github.com/astral-sh/uv/pull/1694

@@ -16,8 +16,8 @@ UV_HTTP_TIMEOUT=${UV_HTTP_TIMEOUT:-500}
 if [ -n "$LLAMA_STACK_DIR" ]; then
   echo "Using llama-stack-dir=$LLAMA_STACK_DIR"
 fi
-if [ -n "$LLAMA_MODELS_DIR" ]; then
-  echo "Using llama-models-dir=$LLAMA_MODELS_DIR"
+if [ -n "$LLAMA_STACK_CLIENT_DIR" ]; then
+  echo "Using llama-stack-client-dir=$LLAMA_STACK_CLIENT_DIR"
 fi

 if [ "$#" -lt 3 ]; then

@@ -52,7 +52,7 @@ ensure_conda_env_python310() {
   local python_version="3.10"

   # Check if conda command is available
-  if ! command -v conda &>/dev/null; then
+  if ! is_command_available conda; then
     printf "${RED}Error: conda command not found. Is Conda installed and in your PATH?${NC}" >&2
     exit 1
   fi

@@ -87,8 +87,6 @@ ensure_conda_env_python310() {
     # these packages are damaged in test-pypi, so install them first
     uv pip install fastapi libcst
     uv pip install --extra-index-url https://test.pypi.org/simple/ \
-      llama-models==$TEST_PYPI_VERSION \
-      llama-stack-client==$TEST_PYPI_VERSION \
       llama-stack==$TEST_PYPI_VERSION \
       $pip_dependencies
     if [ -n "$special_pip_deps" ]; then

@@ -111,22 +109,21 @@ ensure_conda_env_python310() {
   else
     PYPI_VERSION="${PYPI_VERSION:-}"
     if [ -n "$PYPI_VERSION" ]; then
-      SPEC_VERSION="llama-stack==${PYPI_VERSION} llama-models==${PYPI_VERSION} llama-stack-client==${PYPI_VERSION}"
+      SPEC_VERSION="llama-stack==${PYPI_VERSION}"
     else
       SPEC_VERSION="llama-stack"
     fi
     uv pip install --no-cache-dir $SPEC_VERSION
   fi

-  if [ -n "$LLAMA_MODELS_DIR" ]; then
-    if [ ! -d "$LLAMA_MODELS_DIR" ]; then
-      printf "${RED}Warning: LLAMA_MODELS_DIR is set but directory does not exist: $LLAMA_MODELS_DIR${NC}\n" >&2
+  if [ -n "$LLAMA_STACK_CLIENT_DIR" ]; then
+    if [ ! -d "$LLAMA_STACK_CLIENT_DIR" ]; then
+      printf "${RED}Warning: LLAMA_STACK_CLIENT_DIR is set but directory does not exist: $LLAMA_STACK_CLIENT_DIR${NC}\n" >&2
       exit 1
     fi

-    printf "Installing from LLAMA_MODELS_DIR: $LLAMA_MODELS_DIR\n"
-    uv pip uninstall llama-models
-    uv pip install --no-cache-dir -e "$LLAMA_MODELS_DIR"
+    printf "Installing from LLAMA_STACK_CLIENT_DIR: $LLAMA_STACK_CLIENT_DIR\n"
+    uv pip install --no-cache-dir -e "$LLAMA_STACK_CLIENT_DIR"
   fi

   # Install pip dependencies
@@ -1,4 +1,4 @@
-#!/bin/bash
+#!/usr/bin/env bash

 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # All rights reserved.

@@ -6,7 +6,6 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-LLAMA_MODELS_DIR=${LLAMA_MODELS_DIR:-}
 LLAMA_STACK_DIR=${LLAMA_STACK_DIR:-}
 LLAMA_STACK_CLIENT_DIR=${LLAMA_STACK_CLIENT_DIR:-}

@@ -20,26 +19,27 @@ UV_HTTP_TIMEOUT=${UV_HTTP_TIMEOUT:-500}
 # mounting is not supported by docker buildx, so we use COPY instead
 USE_COPY_NOT_MOUNT=${USE_COPY_NOT_MOUNT:-}

-if [ "$#" -lt 6 ]; then
+if [ "$#" -lt 4 ]; then
   # This only works for templates
-  echo "Usage: $0 <template_or_config> <image_name> <container_base> <build_file_path> <host_build_dir> <pip_dependencies> [<special_pip_deps>]" >&2
+  echo "Usage: $0 <template_or_config> <image_name> <container_base> <pip_dependencies> [<special_pip_deps>]" >&2
   exit 1
 fi

 set -euo pipefail

 template_or_config="$1"
-image_name="$2"
-container_base="$3"
-build_file_path="$4"
-host_build_dir="$5"
-pip_dependencies="$6"
-special_pip_deps="${7:-}"
+shift
+image_name="$1"
+shift
+container_base="$1"
+shift
+pip_dependencies="$1"
+shift
+special_pip_deps="${1:-}"


 # Define color codes
 RED='\033[0;31m'
-GREEN='\033[0;32m'
 NC='\033[0m' # No Color

 CONTAINER_BINARY=${CONTAINER_BINARY:-docker}

@@ -47,8 +47,10 @@ CONTAINER_OPTS=${CONTAINER_OPTS:-}

 TEMP_DIR=$(mktemp -d)

+SCRIPT_DIR=$(dirname "$(readlink -f "$0")")
+source "$SCRIPT_DIR/common.sh"
+
 add_to_container() {
-  local input
   output_file="$TEMP_DIR/Containerfile"
   if [ -t 0 ]; then
     printf '%s\n' "$1" >>"$output_file"

@@ -58,15 +60,21 @@ add_to_container() {
   fi
 }

+# Check if container command is available
+if ! is_command_available $CONTAINER_BINARY; then
+  printf "${RED}Error: ${CONTAINER_BINARY} command not found. Is ${CONTAINER_BINARY} installed and in your PATH?${NC}" >&2
+  exit 1
+fi
+
 # Update and install UBI9 components if UBI9 base image is used
 if [[ $container_base == *"registry.access.redhat.com/ubi9"* ]]; then
   add_to_container << EOF
 FROM $container_base
 WORKDIR /app

-RUN microdnf -y update && microdnf install -y iputils net-tools wget \
+RUN dnf -y update && dnf install -y iputils net-tools wget \
     vim-minimal python3.11 python3.11-pip python3.11-wheel \
-    python3.11-setuptools && ln -s /bin/pip3.11 /bin/pip && ln -s /bin/python3.11 /bin/python && microdnf clean all
+    python3.11-setuptools && ln -s /bin/pip3.11 /bin/pip && ln -s /bin/python3.11 /bin/python && dnf clean all

 ENV UV_SYSTEM_PYTHON=1
 RUN pip install uv

@@ -107,7 +115,6 @@ EOF
 fi

 stack_mount="/app/llama-stack-source"
-models_mount="/app/llama-models-source"
 client_mount="/app/llama-stack-client-source"

 install_local_package() {

@@ -131,10 +138,6 @@ EOF
 }

-
-if [ -n "$LLAMA_MODELS_DIR" ]; then
-  install_local_package "$LLAMA_MODELS_DIR" "$models_mount" "LLAMA_MODELS_DIR"
-fi

 if [ -n "$LLAMA_STACK_CLIENT_DIR" ]; then
   install_local_package "$LLAMA_STACK_CLIENT_DIR" "$client_mount" "LLAMA_STACK_CLIENT_DIR"
 fi

@@ -150,12 +153,12 @@ EOF
   add_to_container << EOF
 RUN uv pip install --no-cache --extra-index-url https://test.pypi.org/simple/ \
     --index-strategy unsafe-best-match \
-    llama-models==$TEST_PYPI_VERSION llama-stack-client==$TEST_PYPI_VERSION llama-stack==$TEST_PYPI_VERSION
+    llama-stack==$TEST_PYPI_VERSION

 EOF
 else
   if [ -n "$PYPI_VERSION" ]; then
-    SPEC_VERSION="llama-stack==${PYPI_VERSION} llama-models==${PYPI_VERSION} llama-stack-client==${PYPI_VERSION}"
+    SPEC_VERSION="llama-stack==${PYPI_VERSION}"
   else
     SPEC_VERSION="llama-stack"
   fi

@@ -165,6 +168,11 @@ EOF
   fi
 fi

+# remove uv after installation
+add_to_container << EOF
+RUN pip uninstall -y uv
+EOF
+
 # if template_or_config ends with .yaml, it is not a template and we should not use the --template flag
 if [[ "$template_or_config" != *.yaml ]]; then
   add_to_container << EOF

@@ -185,26 +193,28 @@ RUN mkdir -p /.llama /.cache
 RUN chmod -R g+rw /app /.llama /.cache
 EOF

-printf "Containerfile created successfully in $TEMP_DIR/Containerfile\n\n"
-cat $TEMP_DIR/Containerfile
+printf "Containerfile created successfully in %s/Containerfile\n\n" "$TEMP_DIR"
+cat "$TEMP_DIR"/Containerfile
 printf "\n"

-mounts=""
+# Start building the CLI arguments
+CLI_ARGS=()
+
+# Read CONTAINER_OPTS and put it in an array
+read -ra CLI_ARGS <<< "$CONTAINER_OPTS"
+
 if [ "$USE_COPY_NOT_MOUNT" != "true" ]; then
   if [ -n "$LLAMA_STACK_DIR" ]; then
-    mounts="$mounts -v $(readlink -f $LLAMA_STACK_DIR):$stack_mount"
-  fi
-  if [ -n "$LLAMA_MODELS_DIR" ]; then
-    mounts="$mounts -v $(readlink -f $LLAMA_MODELS_DIR):$models_mount"
+    CLI_ARGS+=("-v" "$(readlink -f "$LLAMA_STACK_DIR"):$stack_mount")
   fi
   if [ -n "$LLAMA_STACK_CLIENT_DIR" ]; then
-    mounts="$mounts -v $(readlink -f $LLAMA_STACK_CLIENT_DIR):$client_mount"
+    CLI_ARGS+=("-v" "$(readlink -f "$LLAMA_STACK_CLIENT_DIR"):$client_mount")
   fi
 fi

-if command -v selinuxenabled &>/dev/null && selinuxenabled; then
+if is_command_available selinuxenabled && selinuxenabled; then
   # Disable SELinux labels -- we don't want to relabel the llama-stack source dir
-  CONTAINER_OPTS="$CONTAINER_OPTS --security-opt label=disable"
+  CLI_ARGS+=("--security-opt" "label=disable")
 fi

 # Set version tag based on PyPI version

@@ -212,7 +222,7 @@ if [ -n "$PYPI_VERSION" ]; then
   version_tag="$PYPI_VERSION"
 elif [ -n "$TEST_PYPI_VERSION" ]; then
   version_tag="test-$TEST_PYPI_VERSION"
-elif [[ -n "$LLAMA_STACK_DIR" || -n "$LLAMA_MODELS_DIR" ]]; then
+elif [[ -n "$LLAMA_STACK_DIR" || -n "$LLAMA_STACK_CLIENT_DIR" ]]; then
   version_tag="dev"
 else
   URL="https://pypi.org/pypi/llama-stack/json"

@@ -225,11 +235,11 @@ image_tag="$image_name:$version_tag"
 # Detect platform architecture
 ARCH=$(uname -m)
 if [ -n "$BUILD_PLATFORM" ]; then
-  PLATFORM="--platform $BUILD_PLATFORM"
+  CLI_ARGS+=("--platform $BUILD_PLATFORM")
 elif [ "$ARCH" = "arm64" ] || [ "$ARCH" = "aarch64" ]; then
-  PLATFORM="--platform linux/arm64"
+  CLI_ARGS+=("--platform" "linux/arm64")
 elif [ "$ARCH" = "x86_64" ]; then
-  PLATFORM="--platform linux/amd64"
+  CLI_ARGS+=("--platform" "linux/amd64")
 else
   echo "Unsupported architecture: $ARCH"
   exit 1

@@ -238,8 +248,13 @@ fi
 echo "PWD: $(pwd)"
 echo "Containerfile: $TEMP_DIR/Containerfile"
 set -x
-$CONTAINER_BINARY build $CONTAINER_OPTS $PLATFORM -t $image_tag \
-  -f "$TEMP_DIR/Containerfile" "." $mounts --progress=plain
+
+$CONTAINER_BINARY build \
+  "${CLI_ARGS[@]}" \
+  -t "$image_tag" \
+  -f "$TEMP_DIR/Containerfile" \
+  "." \
+  --progress=plain

 # clean up tmp/configs
 set +x
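The switch from string-concatenated `$mounts`/`$CONTAINER_OPTS`/`$PLATFORM` to a `CLI_ARGS` bash array avoids word-splitting and quoting pitfalls when paths contain spaces. The same design idea expressed in Python terms (purely illustrative, not part of the repo; the mount path and image tag are made up) is to pass arguments as a list rather than a single shell string:

```python
# Illustrative only: building container arguments as a list mirrors the bash
# CLI_ARGS array above and sidesteps quoting/word-splitting bugs.
import subprocess

cli_args = ["--platform", "linux/amd64"]
cli_args += ["-v", "/home/user/llama-stack:/app/llama-stack-source"]  # hypothetical mount

subprocess.run(
    ["docker", "build", *cli_args, "-t", "distribution-ollama:dev",
     "-f", "/tmp/Containerfile", ".", "--progress=plain"],
    check=True,
)
```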
@@ -9,8 +9,8 @@
 # TODO: combine this with build_conda_env.sh since it is almost identical
 # the only difference is that we don't do any conda-specific setup

-LLAMA_MODELS_DIR=${LLAMA_MODELS_DIR:-}
 LLAMA_STACK_DIR=${LLAMA_STACK_DIR:-}
+LLAMA_STACK_CLIENT_DIR=${LLAMA_STACK_CLIENT_DIR:-}
 TEST_PYPI_VERSION=${TEST_PYPI_VERSION:-}
 # This timeout (in seconds) is necessary when installing PyTorch via uv since it's likely to time out
 # Reference: https://github.com/astral-sh/uv/pull/1694
@@ -21,13 +21,13 @@ VIRTUAL_ENV=${VIRTUAL_ENV:-}
 if [ -n "$LLAMA_STACK_DIR" ]; then
   echo "Using llama-stack-dir=$LLAMA_STACK_DIR"
 fi
-if [ -n "$LLAMA_MODELS_DIR" ]; then
-  echo "Using llama-models-dir=$LLAMA_MODELS_DIR"
+if [ -n "$LLAMA_STACK_CLIENT_DIR" ]; then
+  echo "Using llama-stack-client-dir=$LLAMA_STACK_CLIENT_DIR"
 fi

 if [ "$#" -lt 2 ]; then
-  echo "Usage: $0 <distribution_type> <env_name> <pip_dependencies> [<special_pip_deps>]" >&2
-  echo "Example: $0 <distribution_type> mybuild ./my-stack-build.yaml 'numpy pandas scipy'" >&2
+  echo "Usage: $0 <env_name> <pip_dependencies> [<special_pip_deps>]" >&2
+  echo "Example: $0 mybuild ./my-stack-build.yaml 'numpy pandas scipy'" >&2
   exit 1
 fi

@@ -95,7 +95,7 @@ run() {
     # we are building a command line so word splitting is expected
     uv pip install --extra-index-url https://test.pypi.org/simple/ \
       --index-strategy unsafe-best-match \
-      llama-models=="$TEST_PYPI_VERSION" llama-stack=="$TEST_PYPI_VERSION" \
+      llama-stack=="$TEST_PYPI_VERSION" \
       $pip_dependencies
     if [ -n "$special_pip_deps" ]; then
       IFS='#' read -ra parts <<<"$special_pip_deps"
@@ -120,15 +120,14 @@ run() {
       uv pip install --no-cache-dir llama-stack
     fi

-    if [ -n "$LLAMA_MODELS_DIR" ]; then
-      if [ ! -d "$LLAMA_MODELS_DIR" ]; then
-        printf "${RED}Warning: LLAMA_MODELS_DIR is set but directory does not exist: %s${NC}\n" "$LLAMA_MODELS_DIR" >&2
+    if [ -n "$LLAMA_STACK_CLIENT_DIR" ]; then
+      if [ ! -d "$LLAMA_STACK_CLIENT_DIR" ]; then
+        printf "${RED}Warning: LLAMA_STACK_CLIENT_DIR is set but directory does not exist: %s${NC}\n" "$LLAMA_STACK_CLIENT_DIR" >&2
         exit 1
       fi

-      printf "Installing from LLAMA_MODELS_DIR: %s\n" "$LLAMA_MODELS_DIR"
-      uv pip uninstall llama-models
-      uv pip install --no-cache-dir -e "$LLAMA_MODELS_DIR"
+      printf "Installing from LLAMA_STACK_CLIENT_DIR: %s\n" "$LLAMA_STACK_CLIENT_DIR"
+      uv pip install --no-cache-dir -e "$LLAMA_STACK_CLIENT_DIR"
     fi

     # Install pip dependencies
@@ -1,47 +0,0 @@
-#!/bin/bash
-
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the terms described in the LICENSE file in
-# the root directory of this source tree.
-
-CONTAINER_BINARY=${CONTAINER_BINARY:-docker}
-CONTAINER_OPTS=${CONTAINER_OPTS:-}
-LLAMA_STACK_DIR=${LLAMA_STACK_DIR:-}
-
-set -euo pipefail
-
-error_handler() {
-  echo "Error occurred in script at line: ${1}" >&2
-  exit 1
-}
-
-trap 'error_handler ${LINENO}' ERR
-
-if [ $# -lt 2 ]; then
-  echo "Usage: $0 <container name> <build file path>"
-  exit 1
-fi
-
-container_image="$1"
-host_build_dir="$2"
-container_build_dir="/app/builds"
-
-if command -v selinuxenabled &> /dev/null && selinuxenabled; then
-  # Disable SELinux labels
-  CONTAINER_OPTS="$CONTAINER_OPTS --security-opt label=disable"
-fi
-
-mounts=""
-if [ -n "$LLAMA_STACK_DIR" ]; then
-  mounts="$mounts -v $(readlink -f $LLAMA_STACK_DIR):/app/llama-stack-source"
-fi
-
-set -x
-$CONTAINER_BINARY run $CONTAINER_OPTS -it \
-  --entrypoint "/usr/local/bin/llama" \
-  -v $host_build_dir:$container_build_dir \
-  $mounts \
-  $container_image \
-  stack configure ./llamastack-build.yaml --output-dir $container_build_dir
@@ -13,7 +13,7 @@ from llama_stack.providers.datatypes import Api, ProviderSpec


 def stack_apis() -> List[Api]:
-    return [v for v in Api]
+    return list(Api)


 class AutoRoutedApiInfo(BaseModel):
@@ -55,7 +55,7 @@ def builtin_automatically_routed_apis() -> List[AutoRoutedApiInfo]:


 def providable_apis() -> List[Api]:
-    routing_table_apis = set(x.routing_table_api for x in builtin_automatically_routed_apis())
+    routing_table_apis = {x.routing_table_api for x in builtin_automatically_routed_apis()}
     return [api for api in Api if api not in routing_table_apis and api != Api.inspect]

@@ -104,7 +104,7 @@ def convert_to_pydantic(annotation: Any, value: Any) -> Any:
         logger.warning(
             f"Warning: direct client failed to convert parameter {value} into {annotation}: {e}",
         )
-        return value
+        raise ValueError(f"Failed to convert parameter {value} into {annotation}: {e}") from e


 class LlamaStackAsLibraryClient(LlamaStackClient):
@@ -324,6 +324,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
             await end_trace()

         json_content = json.dumps(convert_pydantic_to_json_value(result))

         mock_response = httpx.Response(
             status_code=httpx.codes.OK,
             content=json_content.encode("utf-8"),
@@ -335,7 +336,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
                 url=options.url,
                 params=options.params,
                 headers=options.headers or {},
-                json=options.json_data,
+                json=convert_pydantic_to_json_value(body),
             ),
         )
         response = APIResponse(
@@ -384,7 +385,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
                 url=options.url,
                 params=options.params,
                 headers=options.headers or {},
-                json=options.json_data,
+                json=convert_pydantic_to_json_value(body),
             ),
         )

@@ -5,9 +5,9 @@
 # the root directory of this source tree.
 import importlib
 import inspect
-import logging
-from typing import Any, Dict, List, Set
+from typing import Any, Dict, List, Set, Tuple

+from llama_stack import logcat
 from llama_stack.apis.agents import Agents
 from llama_stack.apis.benchmarks import Benchmarks
 from llama_stack.apis.datasetio import DatasetIO
@@ -50,8 +50,6 @@ from llama_stack.providers.datatypes import (
     VectorDBsProtocolPrivate,
 )

-log = logging.getLogger(__name__)
-

 class InvalidProviderError(Exception):
     pass
@@ -104,60 +102,43 @@ class ProviderWithSpec(Provider):
 ProviderRegistry = Dict[Api, Dict[str, ProviderSpec]]


-# TODO: this code is not very straightforward to follow and needs one more round of refactoring
 async def resolve_impls(
     run_config: StackRunConfig,
     provider_registry: ProviderRegistry,
     dist_registry: DistributionRegistry,
 ) -> Dict[Api, Any]:
     """
-    Does two things:
-    - flatmaps, sorts and resolves the providers in dependency order
-    - for each API, produces either a (local, passthrough or router) implementation
+    Resolves provider implementations by:
+    1. Validating and organizing providers.
+    2. Sorting them in dependency order.
+    3. Instantiating them with required dependencies.
     """
-    routing_table_apis = set(x.routing_table_api for x in builtin_automatically_routed_apis())
-    router_apis = set(x.router_api for x in builtin_automatically_routed_apis())
+    routing_table_apis = {x.routing_table_api for x in builtin_automatically_routed_apis()}
+    router_apis = {x.router_api for x in builtin_automatically_routed_apis()}

-    providers_with_specs = {}
-
-    for api_str, providers in run_config.providers.items():
-        api = Api(api_str)
-        if api in routing_table_apis:
-            raise ValueError(f"Provider for `{api_str}` is automatically provided and cannot be overridden")
-
-        specs = {}
-        for provider in providers:
-            if provider.provider_type not in provider_registry[api]:
-                raise ValueError(f"Provider `{provider.provider_type}` is not available for API `{api}`")
-
-            p = provider_registry[api][provider.provider_type]
-            if p.deprecation_error:
-                log.error(p.deprecation_error, "red", attrs=["bold"])
-                raise InvalidProviderError(p.deprecation_error)
-
-            elif p.deprecation_warning:
-                log.warning(
-                    f"Provider `{provider.provider_type}` for API `{api}` is deprecated and will be removed in a future release: {p.deprecation_warning}",
-                )
-            p.deps__ = [a.value for a in p.api_dependencies] + [a.value for a in p.optional_api_dependencies]
-            spec = ProviderWithSpec(
-                spec=p,
-                **(provider.model_dump()),
-            )
-            specs[provider.provider_id] = spec
-
-        key = api_str if api not in router_apis else f"inner-{api_str}"
-        providers_with_specs[key] = specs
+    providers_with_specs = validate_and_prepare_providers(
+        run_config, provider_registry, routing_table_apis, router_apis
+    )

     apis_to_serve = run_config.apis or set(
         list(providers_with_specs.keys()) + [x.value for x in routing_table_apis] + [x.value for x in router_apis]
     )

+    providers_with_specs.update(specs_for_autorouted_apis(apis_to_serve))
+
+    sorted_providers = sort_providers_by_deps(providers_with_specs, run_config)
+
+    return await instantiate_providers(sorted_providers, router_apis, dist_registry)
+
+
+def specs_for_autorouted_apis(apis_to_serve: List[str] | Set[str]) -> Dict[str, Dict[str, ProviderWithSpec]]:
+    """Generates specifications for automatically routed APIs."""
+    specs = {}
     for info in builtin_automatically_routed_apis():
         if info.router_api.value not in apis_to_serve:
             continue

-        providers_with_specs[info.routing_table_api.value] = {
+        specs[info.routing_table_api.value] = {
             "__builtin__": ProviderWithSpec(
                 provider_id="__routing_table__",
                 provider_type="__routing_table__",
@@ -167,12 +148,12 @@ async def resolve_impls(
                     router_api=info.router_api,
                     module="llama_stack.distribution.routers",
                     api_dependencies=[],
-                    deps__=([f"inner-{info.router_api.value}"]),
+                    deps__=[f"inner-{info.router_api.value}"],
                 ),
             )
         }

-        providers_with_specs[info.router_api.value] = {
+        specs[info.router_api.value] = {
             "__builtin__": ProviderWithSpec(
                 provider_id="__autorouted__",
                 provider_type="__autorouted__",
@@ -182,12 +163,67 @@ async def resolve_impls(
                     module="llama_stack.distribution.routers",
                     routing_table_api=info.routing_table_api,
                     api_dependencies=[info.routing_table_api],
-                    deps__=([info.routing_table_api.value]),
+                    deps__=[info.routing_table_api.value],
                 ),
             )
         }
+    return specs

-    sorted_providers = topological_sort({k: v.values() for k, v in providers_with_specs.items()})
+
+def validate_and_prepare_providers(
+    run_config: StackRunConfig, provider_registry: ProviderRegistry, routing_table_apis: Set[Api], router_apis: Set[Api]
+) -> Dict[str, Dict[str, ProviderWithSpec]]:
+    """Validates providers, handles deprecations, and organizes them into a spec dictionary."""
+    providers_with_specs: Dict[str, Dict[str, ProviderWithSpec]] = {}
+
+    for api_str, providers in run_config.providers.items():
+        api = Api(api_str)
+        if api in routing_table_apis:
+            raise ValueError(f"Provider for `{api_str}` is automatically provided and cannot be overridden")
+
+        specs = {}
+        for provider in providers:
+            if not provider.provider_id or provider.provider_id == "__disabled__":
+                logcat.warning("core", f"Provider `{provider.provider_type}` for API `{api}` is disabled")
+                continue
+
+            validate_provider(provider, api, provider_registry)
+            p = provider_registry[api][provider.provider_type]
+            p.deps__ = [a.value for a in p.api_dependencies] + [a.value for a in p.optional_api_dependencies]
+            spec = ProviderWithSpec(spec=p, **provider.model_dump())
+            specs[provider.provider_id] = spec
+
+        key = api_str if api not in router_apis else f"inner-{api_str}"
+        providers_with_specs[key] = specs
+
+    return providers_with_specs
+
+
+def validate_provider(provider: Provider, api: Api, provider_registry: ProviderRegistry):
+    """Validates if the provider is allowed and handles deprecations."""
+    if provider.provider_type not in provider_registry[api]:
+        raise ValueError(f"Provider `{provider.provider_type}` is not available for API `{api}`")
+
+    p = provider_registry[api][provider.provider_type]
+    if p.deprecation_error:
+        logcat.error("core", p.deprecation_error)
+        raise InvalidProviderError(p.deprecation_error)
+    elif p.deprecation_warning:
+        logcat.warning(
+            "core",
+            f"Provider `{provider.provider_type}` for API `{api}` is deprecated and will be removed in a future release: {p.deprecation_warning}",
+        )
+
+
+def sort_providers_by_deps(
+    providers_with_specs: Dict[str, Dict[str, ProviderWithSpec]], run_config: StackRunConfig
+) -> List[Tuple[str, ProviderWithSpec]]:
+    """Sorts providers based on their dependencies."""
+    sorted_providers: List[Tuple[str, ProviderWithSpec]] = topological_sort(
+        {k: list(v.values()) for k, v in providers_with_specs.items()}
+    )
+
+    # Append built-in "inspect" provider
     apis = [x[1].spec.api for x in sorted_providers]
     sorted_providers.append(
         (
@@ -195,28 +231,31 @@ async def resolve_impls(
             ProviderWithSpec(
                 provider_id="__builtin__",
                 provider_type="__builtin__",
-                config={
-                    "run_config": run_config.dict(),
-                },
+                config={"run_config": run_config.model_dump()},
                 spec=InlineProviderSpec(
                     api=Api.inspect,
                     provider_type="__builtin__",
                     config_class="llama_stack.distribution.inspect.DistributionInspectConfig",
                     module="llama_stack.distribution.inspect",
                     api_dependencies=apis,
-                    deps__=([x.value for x in apis]),
+                    deps__=[x.value for x in apis],
                 ),
             ),
         )
     )

-    log.info(f"Resolved {len(sorted_providers)} providers")
+    logcat.debug("core", f"Resolved {len(sorted_providers)} providers")
     for api_str, provider in sorted_providers:
-        log.info(f" {api_str} => {provider.provider_id}")
-    log.info("")
+        logcat.debug("core", f" {api_str} => {provider.provider_id}")
+    return sorted_providers

-    impls = {}
-    inner_impls_by_provider_id = {f"inner-{x.value}": {} for x in router_apis}
+
+async def instantiate_providers(
+    sorted_providers: List[Tuple[str, ProviderWithSpec]], router_apis: Set[Api], dist_registry: DistributionRegistry
+) -> Dict:
+    """Instantiates providers asynchronously while managing dependencies."""
+    impls: Dict[Api, Any] = {}
+    inner_impls_by_provider_id: Dict[str, Dict[str, Any]] = {f"inner-{x.value}": {} for x in router_apis}
     for api_str, provider in sorted_providers:
         deps = {a: impls[a] for a in provider.spec.api_dependencies}
         for a in provider.spec.optional_api_dependencies:
@@ -227,14 +266,9 @@ async def resolve_impls(
         if isinstance(provider.spec, RoutingTableProviderSpec):
             inner_impls = inner_impls_by_provider_id[f"inner-{provider.spec.router_api.value}"]

-        impl = await instantiate_provider(
-            provider,
-            deps,
-            inner_impls,
-            dist_registry,
-        )
-        # TODO: ugh slightly redesign this shady looking code
-        if "inner-" in api_str:
+        impl = await instantiate_provider(provider, deps, inner_impls, dist_registry)
+        if api_str.startswith("inner-"):
             inner_impls_by_provider_id[api_str][provider.provider_id] = impl
         else:
             api = Api(api_str)
@@ -245,7 +279,7 @@ async def resolve_impls(

 def topological_sort(
     providers_with_specs: Dict[str, List[ProviderWithSpec]],
-) -> List[ProviderWithSpec]:
+) -> List[Tuple[str, ProviderWithSpec]]:
     def dfs(kv, visited: Set[str], stack: List[str]):
         api_str, providers = kv
         visited.add(api_str)
@@ -261,8 +295,8 @@ def topological_sort(
         stack.append(api_str)

-    visited = set()
-    stack = []
+    visited: Set[str] = set()
+    stack: List[str] = []

     for api_str, providers in providers_with_specs.items():
         if api_str not in visited:
@@ -272,13 +306,14 @@ def topological_sort(
     for api_str in stack:
         for provider in providers_with_specs[api_str]:
             flattened.append((api_str, provider))

     return flattened


 # returns a class implementing the protocol corresponding to the Api
 async def instantiate_provider(
     provider: ProviderWithSpec,
-    deps: Dict[str, Any],
+    deps: Dict[Api, Any],
     inner_impls: Dict[str, Any],
     dist_registry: DistributionRegistry,
 ):
@@ -286,8 +321,10 @@ async def instantiate_provider(
     additional_protocols = additional_protocols_map()

     provider_spec = provider.spec
-    module = importlib.import_module(provider_spec.module)
+    if not hasattr(provider_spec, "module"):
+        raise AttributeError(f"ProviderSpec of type {type(provider_spec)} does not have a 'module' attribute")
+
+    module = importlib.import_module(provider_spec.module)
     args = []
     if isinstance(provider_spec, RemoteProviderSpec):
         config_type = instantiate_class_type(provider_spec.config_class)
@@ -350,7 +387,7 @@ def check_protocol_compliance(obj: Any, protocol: Any) -> None:
     obj_params = set(obj_sig.parameters)
     obj_params.discard("self")
     if not (proto_params <= obj_params):
-        log.error(f"Method {name} incompatible proto: {proto_params} vs. obj: {obj_params}")
+        logcat.error("core", f"Method {name} incompatible proto: {proto_params} vs. obj: {obj_params}")
         missing_methods.append((name, "signature_mismatch"))
     else:
         # Check if the method is actually implemented in the class
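The resolver hunks above order providers with a depth-first topological sort over their declared dependencies before instantiating them. A rough standalone sketch of that ordering step only, with made-up provider names and not the project's actual helper:

from typing import Dict, List, Set

def topo_order(deps: Dict[str, List[str]]) -> List[str]:
    # Visit each node depth-first and emit it only after all of its
    # dependencies have been emitted, so dependents come later in the list.
    visited: Set[str] = set()
    order: List[str] = []

    def dfs(node: str) -> None:
        visited.add(node)
        for dep in deps.get(node, []):
            if dep not in visited:
                dfs(dep)
        order.append(node)

    for node in deps:
        if node not in visited:
            dfs(node)
    return order

# Example: "agents" depends on "inference" and "safety".
print(topo_order({"inference": [], "safety": [], "agents": ["inference", "safety"]}))
# -> ['inference', 'safety', 'agents']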
@@ -6,6 +6,7 @@
 from typing import Any, AsyncGenerator, Dict, List, Optional

+from llama_stack import logcat
 from llama_stack.apis.common.content_types import (
     URL,
     InterleavedContent,
@@ -52,7 +53,6 @@ from llama_stack.apis.tools import (
 )
 from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
 from llama_stack.providers.datatypes import RoutingTable
-from llama_stack.providers.utils.inference.prompt_adapter import get_default_tool_prompt_format


 class VectorIORouter(VectorIO):
@@ -62,12 +62,15 @@ class VectorIORouter(VectorIO):
         self,
         routing_table: RoutingTable,
     ) -> None:
+        logcat.debug("core", "Initializing VectorIORouter")
         self.routing_table = routing_table

     async def initialize(self) -> None:
+        logcat.debug("core", "VectorIORouter.initialize")
         pass

     async def shutdown(self) -> None:
+        logcat.debug("core", "VectorIORouter.shutdown")
         pass

     async def register_vector_db(
@@ -78,6 +81,10 @@ class VectorIORouter(VectorIO):
         provider_id: Optional[str] = None,
         provider_vector_db_id: Optional[str] = None,
     ) -> None:
+        logcat.debug(
+            "core",
+            f"VectorIORouter.register_vector_db: {vector_db_id}, {embedding_model}",
+        )
         await self.routing_table.register_vector_db(
             vector_db_id,
             embedding_model,
@@ -92,6 +99,10 @@ class VectorIORouter(VectorIO):
         chunks: List[Chunk],
         ttl_seconds: Optional[int] = None,
     ) -> None:
+        logcat.debug(
+            "core",
+            f"VectorIORouter.insert_chunks: {vector_db_id}, {len(chunks)} chunks, ttl_seconds={ttl_seconds}, chunk_ids={[chunk.metadata['document_id'] for chunk in chunks[:3]]}{' and more...' if len(chunks) > 3 else ''}",
+        )
         return await self.routing_table.get_provider_impl(vector_db_id).insert_chunks(vector_db_id, chunks, ttl_seconds)

     async def query_chunks(
@@ -100,6 +111,7 @@ class VectorIORouter(VectorIO):
         query: InterleavedContent,
         params: Optional[Dict[str, Any]] = None,
     ) -> QueryChunksResponse:
+        logcat.debug("core", f"VectorIORouter.query_chunks: {vector_db_id}")
         return await self.routing_table.get_provider_impl(vector_db_id).query_chunks(vector_db_id, query, params)


@@ -110,12 +122,15 @@ class InferenceRouter(Inference):
         self,
         routing_table: RoutingTable,
     ) -> None:
+        logcat.debug("core", "Initializing InferenceRouter")
         self.routing_table = routing_table

     async def initialize(self) -> None:
+        logcat.debug("core", "InferenceRouter.initialize")
         pass

     async def shutdown(self) -> None:
+        logcat.debug("core", "InferenceRouter.shutdown")
         pass

     async def register_model(
@@ -126,6 +141,10 @@ class InferenceRouter(Inference):
         metadata: Optional[Dict[str, Any]] = None,
         model_type: Optional[ModelType] = None,
     ) -> None:
+        logcat.debug(
+            "core",
+            f"InferenceRouter.register_model: {model_id=} {provider_model_id=} {provider_id=} {metadata=} {model_type=}",
+        )
         await self.routing_table.register_model(model_id, provider_model_id, provider_id, metadata, model_type)

     async def chat_completion(
@@ -141,6 +160,10 @@ class InferenceRouter(Inference):
         logprobs: Optional[LogProbConfig] = None,
         tool_config: Optional[ToolConfig] = None,
     ) -> AsyncGenerator:
+        logcat.debug(
+            "core",
+            f"InferenceRouter.chat_completion: {model_id=}, {stream=}, {messages=}, {tools=}, {tool_config=}, {response_format=}",
+        )
         model = await self.routing_table.get_model(model_id)
         if model is None:
             raise ValueError(f"Model '{model_id}' not found")
@@ -159,8 +182,6 @@ class InferenceRouter(Inference):
             params["tool_prompt_format"] = tool_prompt_format
         tool_config = ToolConfig(**params)

-        tool_config.tool_prompt_format = tool_config.tool_prompt_format or get_default_tool_prompt_format(model_id)
-
         tools = tools or []
         if tool_config.tool_choice == ToolChoice.none:
             tools = []
@@ -201,6 +222,10 @@ class InferenceRouter(Inference):
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> AsyncGenerator:
+        logcat.debug(
+            "core",
+            f"InferenceRouter.completion: {model_id=}, {stream=}, {content=}, {sampling_params=}, {response_format=}",
+        )
         model = await self.routing_table.get_model(model_id)
         if model is None:
             raise ValueError(f"Model '{model_id}' not found")
@@ -228,6 +253,7 @@ class InferenceRouter(Inference):
         output_dimension: Optional[int] = None,
         task_type: Optional[EmbeddingTaskType] = None,
     ) -> EmbeddingsResponse:
+        logcat.debug("core", f"InferenceRouter.embeddings: {model_id}")
         model = await self.routing_table.get_model(model_id)
         if model is None:
             raise ValueError(f"Model '{model_id}' not found")
@@ -247,12 +273,15 @@ class SafetyRouter(Safety):
         self,
         routing_table: RoutingTable,
     ) -> None:
+        logcat.debug("core", "Initializing SafetyRouter")
         self.routing_table = routing_table

     async def initialize(self) -> None:
+        logcat.debug("core", "SafetyRouter.initialize")
         pass

     async def shutdown(self) -> None:
+        logcat.debug("core", "SafetyRouter.shutdown")
         pass

     async def register_shield(
@@ -262,6 +291,7 @@ class SafetyRouter(Safety):
         provider_id: Optional[str] = None,
         params: Optional[Dict[str, Any]] = None,
     ) -> Shield:
+        logcat.debug("core", f"SafetyRouter.register_shield: {shield_id}")
         return await self.routing_table.register_shield(shield_id, provider_shield_id, provider_id, params)

     async def run_shield(
@@ -270,6 +300,7 @@ class SafetyRouter(Safety):
         messages: List[Message],
         params: Dict[str, Any] = None,
     ) -> RunShieldResponse:
+        logcat.debug("core", f"SafetyRouter.run_shield: {shield_id}")
         return await self.routing_table.get_provider_impl(shield_id).run_shield(
             shield_id=shield_id,
             messages=messages,
@@ -282,12 +313,15 @@ class DatasetIORouter(DatasetIO):
         self,
         routing_table: RoutingTable,
     ) -> None:
+        logcat.debug("core", "Initializing DatasetIORouter")
         self.routing_table = routing_table

     async def initialize(self) -> None:
+        logcat.debug("core", "DatasetIORouter.initialize")
         pass

     async def shutdown(self) -> None:
+        logcat.debug("core", "DatasetIORouter.shutdown")
         pass

     async def get_rows_paginated(
@@ -297,6 +331,10 @@ class DatasetIORouter(DatasetIO):
         page_token: Optional[str] = None,
         filter_condition: Optional[str] = None,
     ) -> PaginatedRowsResult:
+        logcat.debug(
+            "core",
+            f"DatasetIORouter.get_rows_paginated: {dataset_id}, rows_in_page={rows_in_page}",
+        )
         return await self.routing_table.get_provider_impl(dataset_id).get_rows_paginated(
             dataset_id=dataset_id,
             rows_in_page=rows_in_page,
@@ -305,6 +343,7 @@ class DatasetIORouter(DatasetIO):
         )

     async def append_rows(self, dataset_id: str, rows: List[Dict[str, Any]]) -> None:
+        logcat.debug("core", f"DatasetIORouter.append_rows: {dataset_id}, {len(rows)} rows")
         return await self.routing_table.get_provider_impl(dataset_id).append_rows(
             dataset_id=dataset_id,
             rows=rows,
@@ -316,12 +355,15 @@ class ScoringRouter(Scoring):
         self,
         routing_table: RoutingTable,
     ) -> None:
+        logcat.debug("core", "Initializing ScoringRouter")
         self.routing_table = routing_table

     async def initialize(self) -> None:
+        logcat.debug("core", "ScoringRouter.initialize")
         pass

     async def shutdown(self) -> None:
+        logcat.debug("core", "ScoringRouter.shutdown")
         pass

     async def score_batch(
@@ -330,6 +372,7 @@ class ScoringRouter(Scoring):
         scoring_functions: Dict[str, Optional[ScoringFnParams]] = None,
         save_results_dataset: bool = False,
     ) -> ScoreBatchResponse:
+        logcat.debug("core", f"ScoringRouter.score_batch: {dataset_id}")
         res = {}
         for fn_identifier in scoring_functions.keys():
             score_response = await self.routing_table.get_provider_impl(fn_identifier).score_batch(
@@ -350,6 +393,10 @@ class ScoringRouter(Scoring):
         input_rows: List[Dict[str, Any]],
         scoring_functions: Dict[str, Optional[ScoringFnParams]] = None,
     ) -> ScoreResponse:
+        logcat.debug(
+            "core",
+            f"ScoringRouter.score: {len(input_rows)} rows, {len(scoring_functions)} functions",
+        )
         res = {}
         # look up and map each scoring function to its provider impl
         for fn_identifier in scoring_functions.keys():
@@ -367,22 +414,26 @@ class EvalRouter(Eval):
         self,
         routing_table: RoutingTable,
     ) -> None:
+        logcat.debug("core", "Initializing EvalRouter")
         self.routing_table = routing_table

     async def initialize(self) -> None:
+        logcat.debug("core", "EvalRouter.initialize")
         pass

     async def shutdown(self) -> None:
+        logcat.debug("core", "EvalRouter.shutdown")
         pass

     async def run_eval(
         self,
         benchmark_id: str,
-        task_config: BenchmarkConfig,
+        benchmark_config: BenchmarkConfig,
     ) -> Job:
+        logcat.debug("core", f"EvalRouter.run_eval: {benchmark_id}")
         return await self.routing_table.get_provider_impl(benchmark_id).run_eval(
             benchmark_id=benchmark_id,
-            task_config=task_config,
+            benchmark_config=benchmark_config,
         )

     async def evaluate_rows(
@@ -390,13 +441,14 @@ class EvalRouter(Eval):
         benchmark_id: str,
         input_rows: List[Dict[str, Any]],
         scoring_functions: List[str],
-        task_config: BenchmarkConfig,
+        benchmark_config: BenchmarkConfig,
     ) -> EvaluateResponse:
+        logcat.debug("core", f"EvalRouter.evaluate_rows: {benchmark_id}, {len(input_rows)} rows")
         return await self.routing_table.get_provider_impl(benchmark_id).evaluate_rows(
             benchmark_id=benchmark_id,
             input_rows=input_rows,
             scoring_functions=scoring_functions,
-            task_config=task_config,
+            benchmark_config=benchmark_config,
         )

     async def job_status(
@@ -404,6 +456,7 @@ class EvalRouter(Eval):
         benchmark_id: str,
         job_id: str,
     ) -> Optional[JobStatus]:
+        logcat.debug("core", f"EvalRouter.job_status: {benchmark_id}, {job_id}")
         return await self.routing_table.get_provider_impl(benchmark_id).job_status(benchmark_id, job_id)

     async def job_cancel(
@@ -411,6 +464,7 @@ class EvalRouter(Eval):
         benchmark_id: str,
         job_id: str,
     ) -> None:
+        logcat.debug("core", f"EvalRouter.job_cancel: {benchmark_id}, {job_id}")
         await self.routing_table.get_provider_impl(benchmark_id).job_cancel(
             benchmark_id,
             job_id,
@@ -421,6 +475,7 @@ class EvalRouter(Eval):
         benchmark_id: str,
         job_id: str,
     ) -> EvaluateResponse:
+        logcat.debug("core", f"EvalRouter.job_result: {benchmark_id}, {job_id}")
         return await self.routing_table.get_provider_impl(benchmark_id).job_result(
             benchmark_id,
             job_id,
@@ -433,6 +488,7 @@ class ToolRuntimeRouter(ToolRuntime):
         self,
         routing_table: RoutingTable,
     ) -> None:
+        logcat.debug("core", "Initializing ToolRuntimeRouter.RagToolImpl")
         self.routing_table = routing_table

     async def query(
@@ -441,7 +497,8 @@ class ToolRuntimeRouter(ToolRuntime):
         vector_db_ids: List[str],
         query_config: Optional[RAGQueryConfig] = None,
     ) -> RAGQueryResult:
-        return await self.routing_table.get_provider_impl("query_from_memory").query(
+        logcat.debug("core", f"ToolRuntimeRouter.RagToolImpl.query: {vector_db_ids}")
+        return await self.routing_table.get_provider_impl("knowledge_search").query(
             content, vector_db_ids, query_config
         )

@@ -451,6 +508,10 @@ class ToolRuntimeRouter(ToolRuntime):
         vector_db_id: str,
         chunk_size_in_tokens: int = 512,
     ) -> None:
+        logcat.debug(
+            "core",
+            f"ToolRuntimeRouter.RagToolImpl.insert: {vector_db_id}, {len(documents)} documents, chunk_size={chunk_size_in_tokens}",
+        )
         return await self.routing_table.get_provider_impl("insert_into_memory").insert(
             documents, vector_db_id, chunk_size_in_tokens
         )
@@ -459,6 +520,7 @@ class ToolRuntimeRouter(ToolRuntime):
         self,
         routing_table: RoutingTable,
     ) -> None:
+        logcat.debug("core", "Initializing ToolRuntimeRouter")
         self.routing_table = routing_table

         # HACK ALERT this should be in sync with "get_all_api_endpoints()"
@@ -467,12 +529,15 @@ class ToolRuntimeRouter(ToolRuntime):
             setattr(self, f"rag_tool.{method}", getattr(self.rag_tool, method))

     async def initialize(self) -> None:
+        logcat.debug("core", "ToolRuntimeRouter.initialize")
         pass

     async def shutdown(self) -> None:
+        logcat.debug("core", "ToolRuntimeRouter.shutdown")
         pass

     async def invoke_tool(self, tool_name: str, kwargs: Dict[str, Any]) -> Any:
+        logcat.debug("core", f"ToolRuntimeRouter.invoke_tool: {tool_name}")
         return await self.routing_table.get_provider_impl(tool_name).invoke_tool(
             tool_name=tool_name,
             kwargs=kwargs,
@@ -481,4 +546,5 @@ class ToolRuntimeRouter(ToolRuntime):
     async def list_runtime_tools(
         self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
     ) -> List[ToolDef]:
+        logcat.debug("core", f"ToolRuntimeRouter.list_runtime_tools: {tool_group_id}")
         return await self.routing_table.get_provider_impl(tool_group_id).list_tools(tool_group_id, mcp_endpoint)
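Each router class touched above follows the same shape: log the call at debug level, ask the routing table for the provider implementation that owns the resource, and forward the call unchanged. A stripped-down sketch of that delegation pattern, using hypothetical names and standard logging rather than the project's logcat helper:

import logging
from typing import Any, Dict

log = logging.getLogger("router")

class SimpleRouter:
    """Delegates each call to whichever backend the routing table maps the key to."""

    def __init__(self, routing_table: Dict[str, Any]) -> None:
        log.debug("Initializing SimpleRouter")
        self.routing_table = routing_table

    def invoke(self, key: str, payload: Any) -> Any:
        log.debug("SimpleRouter.invoke: %s", key)
        backend = self.routing_table[key]  # pick the provider registered for this key
        return backend.handle(payload)     # forward the call unchanged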
@@ -318,14 +318,7 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
             )
         model = await self.get_object_by_identifier("model", embedding_model)
         if model is None:
-            if embedding_model == "all-MiniLM-L6-v2":
-                raise ValueError(
-                    "Embeddings are now served via Inference providers. "
-                    "Please upgrade your run.yaml to include inline::sentence-transformer as an additional inference provider. "
-                    "See https://github.com/meta-llama/llama-stack/blob/main/llama_stack/templates/together/run.yaml for an example."
-                )
-            else:
-                raise ValueError(f"Model {embedding_model} not found")
+            raise ValueError(f"Model {embedding_model} not found")
         if model.model_type != ModelType.embedding:
             raise ValueError(f"Model {embedding_model} is not an embedding model")
         if "embedding_dimension" not in model.metadata:
@ -26,9 +26,9 @@ from fastapi import Path as FastapiPath
|
||||||
from fastapi.exceptions import RequestValidationError
|
from fastapi.exceptions import RequestValidationError
|
||||||
from fastapi.responses import JSONResponse, StreamingResponse
|
from fastapi.responses import JSONResponse, StreamingResponse
|
||||||
from pydantic import BaseModel, ValidationError
|
from pydantic import BaseModel, ValidationError
|
||||||
from termcolor import cprint
|
|
||||||
from typing_extensions import Annotated
|
from typing_extensions import Annotated
|
||||||
|
|
||||||
|
from llama_stack import logcat
|
||||||
from llama_stack.distribution.datatypes import StackRunConfig
|
from llama_stack.distribution.datatypes import StackRunConfig
|
||||||
from llama_stack.distribution.distribution import builtin_automatically_routed_apis
|
from llama_stack.distribution.distribution import builtin_automatically_routed_apis
|
||||||
from llama_stack.distribution.request_headers import set_request_provider_data
|
from llama_stack.distribution.request_headers import set_request_provider_data
|
||||||
|
|
@ -55,7 +55,7 @@ from .endpoints import get_all_api_endpoints
|
||||||
REPO_ROOT = Path(__file__).parent.parent.parent.parent
|
REPO_ROOT = Path(__file__).parent.parent.parent.parent
|
||||||
|
|
||||||
logging.basicConfig(level=logging.INFO, format="%(levelname)s %(asctime)s %(name)s:%(lineno)d: %(message)s")
|
logging.basicConfig(level=logging.INFO, format="%(levelname)s %(asctime)s %(name)s:%(lineno)d: %(message)s")
|
||||||
logger = logging.getLogger(__name__)
|
logcat.init()
|
||||||
|
|
||||||
|
|
||||||
def warn_with_traceback(message, category, filename, lineno, file=None, line=None):
|
def warn_with_traceback(message, category, filename, lineno, file=None, line=None):
|
||||||
|
|
@ -142,23 +142,23 @@ def handle_signal(app, signum, _) -> None:
|
||||||
not block the current execution.
|
not block the current execution.
|
||||||
"""
|
"""
|
||||||
signame = signal.Signals(signum).name
|
signame = signal.Signals(signum).name
|
||||||
print(f"Received signal {signame} ({signum}). Exiting gracefully...")
|
logcat.info("server", f"Received signal {signame} ({signum}). Exiting gracefully...")
|
||||||
|
|
||||||
async def shutdown():
|
async def shutdown():
|
||||||
try:
|
try:
|
||||||
# Gracefully shut down implementations
|
# Gracefully shut down implementations
|
||||||
for impl in app.__llama_stack_impls__.values():
|
for impl in app.__llama_stack_impls__.values():
|
||||||
impl_name = impl.__class__.__name__
|
impl_name = impl.__class__.__name__
|
||||||
logger.info("Shutting down %s", impl_name)
|
logcat.info("server", f"Shutting down {impl_name}")
|
||||||
try:
|
try:
|
||||||
if hasattr(impl, "shutdown"):
|
if hasattr(impl, "shutdown"):
|
||||||
await asyncio.wait_for(impl.shutdown(), timeout=5)
|
await asyncio.wait_for(impl.shutdown(), timeout=5)
|
||||||
else:
|
else:
|
||||||
logger.warning("No shutdown method for %s", impl_name)
|
logcat.warning("server", f"No shutdown method for {impl_name}")
|
||||||
except asyncio.TimeoutError:
|
except asyncio.TimeoutError:
|
||||||
logger.exception("Shutdown timeout for %s ", impl_name, exc_info=True)
|
logcat.exception("server", f"Shutdown timeout for {impl_name}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.exception("Failed to shutdown %s: %s", impl_name, {e})
|
logcat.exception("server", f"Failed to shutdown {impl_name}: {e}")
|
||||||
|
|
||||||
# Gather all running tasks
|
# Gather all running tasks
|
||||||
loop = asyncio.get_running_loop()
|
loop = asyncio.get_running_loop()
|
||||||
|
|
@ -172,7 +172,7 @@ def handle_signal(app, signum, _) -> None:
|
||||||
try:
|
try:
|
||||||
await asyncio.wait_for(asyncio.gather(*tasks, return_exceptions=True), timeout=10)
|
await asyncio.wait_for(asyncio.gather(*tasks, return_exceptions=True), timeout=10)
|
||||||
except asyncio.TimeoutError:
|
except asyncio.TimeoutError:
|
||||||
logger.exception("Timeout while waiting for tasks to finish")
|
logcat.exception("server", "Timeout while waiting for tasks to finish")
|
||||||
except asyncio.CancelledError:
|
except asyncio.CancelledError:
|
||||||
pass
|
pass
|
||||||
finally:
|
finally:
|
||||||
|
|
@ -184,9 +184,9 @@ def handle_signal(app, signum, _) -> None:
|
||||||
|
|
||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
async def lifespan(app: FastAPI):
|
async def lifespan(app: FastAPI):
|
||||||
print("Starting up")
|
logcat.info("server", "Starting up")
|
||||||
yield
|
yield
|
||||||
print("Shutting down")
|
logcat.info("server", "Shutting down")
|
||||||
for impl in app.__llama_stack_impls__.values():
|
for impl in app.__llama_stack_impls__.values():
|
||||||
await impl.shutdown()
|
await impl.shutdown()
|
||||||
|
|
||||||
|
|
@ -209,10 +209,10 @@ async def sse_generator(event_gen):
|
||||||
yield create_sse_event(item)
|
yield create_sse_event(item)
|
||||||
await asyncio.sleep(0.01)
|
            await asyncio.sleep(0.01)
    except asyncio.CancelledError:
-        print("Generator cancelled")
+        logcat.info("server", "Generator cancelled")
        await event_gen.aclose()
    except Exception as e:
-        traceback.print_exception(e)
+        logcat.exception("server", "Error in sse_generator")
        yield create_sse_event(
            {
                "error": {

@ -234,7 +234,7 @@ def create_dynamic_typed_route(func: Any, method: str, route: str):
                value = func(**kwargs)
                return await maybe_await(value)
            except Exception as e:
-                traceback.print_exception(e)
+                logcat.exception("server", f"Error in {func.__name__}")
                raise translate_exception(e) from e

    sig = inspect.signature(func)

@ -313,6 +313,8 @@ class ClientVersionMiddleware:


 def main():
+    logcat.init()
+
    """Start the LlamaStack server."""
    parser = argparse.ArgumentParser(description="Start the LlamaStack server.")
    parser.add_argument(

@ -352,10 +354,10 @@ def main():
        for env_pair in args.env:
            try:
                key, value = validate_env_pair(env_pair)
-                print(f"Setting CLI environment variable {key} => {value}")
+                logcat.info("server", f"Setting CLI environment variable {key} => {value}")
                os.environ[key] = value
            except ValueError as e:
-                print(f"Error: {str(e)}")
+                logcat.error("server", f"Error: {str(e)}")
                sys.exit(1)

    if args.yaml_config:

@ -363,12 +365,12 @@ def main():
        config_file = Path(args.yaml_config)
        if not config_file.exists():
            raise ValueError(f"Config file {config_file} does not exist")
-        print(f"Using config file: {config_file}")
+        logcat.info("server", f"Using config file: {config_file}")
    elif args.template:
        config_file = Path(REPO_ROOT) / "llama_stack" / "templates" / args.template / "run.yaml"
        if not config_file.exists():
            raise ValueError(f"Template {args.template} does not exist")
-        print(f"Using template {args.template} config file: {config_file}")
+        logcat.info("server", f"Using template {args.template} config file: {config_file}")
    else:
        raise ValueError("Either --yaml-config or --template must be provided")

@ -376,9 +378,10 @@ def main():
        config = replace_env_vars(yaml.safe_load(fp))
        config = StackRunConfig(**config)

-    print("Run configuration:")
+    logcat.info("server", "Run configuration:")
    safe_config = redact_sensitive_fields(config.model_dump())
-    print(yaml.dump(safe_config, indent=2))
+    for log_line in yaml.dump(safe_config, indent=2).split("\n"):
+        logcat.info("server", log_line)

    app = FastAPI(lifespan=lifespan)
    app.add_middleware(TracingMiddleware)

@ -387,7 +390,8 @@ def main():

    try:
        impls = asyncio.run(construct_stack(config))
-    except InvalidProviderError:
+    except InvalidProviderError as e:
+        logcat.error("server", f"Error: {str(e)}")
        sys.exit(1)

    if Api.telemetry in impls:

@ -432,11 +436,8 @@ def main():
            )
        )

-        cprint(f"Serving API {api_str}", "white", attrs=["bold"])
-        for endpoint in endpoints:
-            cprint(f" {endpoint.method.upper()} {endpoint.route}", "white")
-
-        print("")
+    logcat.debug("server", f"serving APIs: {apis_to_serve}")

    app.exception_handler(RequestValidationError)(global_exception_handler)
    app.exception_handler(Exception)(global_exception_handler)
    signal.signal(signal.SIGINT, functools.partial(handle_signal, app))

@ -462,10 +463,10 @@ def main():
            "ssl_keyfile": keyfile,
            "ssl_certfile": certfile,
        }
-        print(f"HTTPS enabled with certificates:\n  Key: {keyfile}\n  Cert: {certfile}")
+        logcat.info("server", f"HTTPS enabled with certificates:\n  Key: {keyfile}\n  Cert: {certfile}")

    listen_host = ["::", "0.0.0.0"] if not args.disable_ipv6 else "0.0.0.0"
-    print(f"Listening on {listen_host}:{port}")
+    logcat.info("server", f"Listening on {listen_host}:{port}")

    uvicorn_config = {
        "app": app,
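For context, the `sse_generator` hunk above follows a common streaming-endpoint pattern: wrap an async event generator, close it when the client disconnects, and surface failures as a final SSE error frame instead of printing a traceback. A minimal sketch of that pattern, assuming the shape of `create_sse_event` shown in the diff (the surrounding FastAPI wiring is omitted and not part of this change):

```
import asyncio
import json


def create_sse_event(data) -> str:
    # Server-sent events are "data: <payload>\n\n" frames
    return f"data: {json.dumps(data)}\n\n"


async def sse_generator(event_gen):
    # event_gen is assumed to be an async generator of JSON-serializable events
    try:
        async for event in event_gen:
            yield create_sse_event(event)
            await asyncio.sleep(0.01)
    except asyncio.CancelledError:
        # Client went away: finalize the wrapped generator instead of leaking it
        await event_gen.aclose()
    except Exception as exc:
        # Surface the error to the client as one last SSE frame
        yield create_sse_event({"error": {"message": str(exc)}})
```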
@ -5,7 +5,6 @@
 # the root directory of this source tree.

 import importlib.resources
-import logging
 import os
 import re
 from typing import Any, Dict, Optional

@ -13,6 +12,7 @@ from typing import Any, Dict, Optional
 import yaml
 from termcolor import colored

+from llama_stack import logcat
 from llama_stack.apis.agents import Agents
 from llama_stack.apis.batch_inference import BatchInference
 from llama_stack.apis.benchmarks import Benchmarks

@ -39,8 +39,6 @@ from llama_stack.distribution.resolver import ProviderRegistry, resolve_impls
 from llama_stack.distribution.store.registry import create_dist_registry
 from llama_stack.providers.datatypes import Api

-log = logging.getLogger(__name__)
-

 class LlamaStack(
     VectorDBs,

@ -101,12 +99,11 @@ async def register_resources(run_config: StackRunConfig, impls: Dict[Api, Any]):
        objects_to_process = response.data if hasattr(response, "data") else response

        for obj in objects_to_process:
-            log.info(
+            logcat.debug(
+                "core",
                f"{rsrc.capitalize()}: {colored(obj.identifier, 'white', attrs=['bold'])} served by {colored(obj.provider_id, 'white', attrs=['bold'])}",
            )

-    log.info("")
-

 class EnvVarError(Exception):
     def __init__(self, var_name: str, path: str = ""):

@ -155,18 +152,34 @@ def replace_env_vars(config: Any, path: str = "") -> Any:
        return result

    elif isinstance(config, str):
-        pattern = r"\${env\.([A-Z0-9_]+)(?::([^}]*))?}"
+        # Updated pattern to support both default values (:) and conditional values (+)
+        pattern = r"\${env\.([A-Z0-9_]+)(?:([:\+])([^}]*))?}"

        def get_env_var(match):
            env_var = match.group(1)
-            default_val = match.group(2)
+            operator = match.group(2)  # ':' for default, '+' for conditional
+            value_expr = match.group(3)

-            value = os.environ.get(env_var)
-            if not value:
-                if default_val is None:
-                    raise EnvVarError(env_var, path)
-                else:
-                    value = default_val
+            env_value = os.environ.get(env_var)
+
+            if operator == ":":  # Default value syntax: ${env.FOO:default}
+                if not env_value:
+                    if value_expr is None:
+                        raise EnvVarError(env_var, path)
+                    else:
+                        value = value_expr
+                else:
+                    value = env_value
+            elif operator == "+":  # Conditional value syntax: ${env.FOO+value_if_set}
+                if env_value:
+                    value = value_expr
+                else:
+                    # If env var is not set, return empty string for the conditional case
+                    value = ""
+            else:  # No operator case: ${env.FOO}
+                if not env_value:
+                    raise EnvVarError(env_var, path)
+                value = env_value

            # expand "~" from the values
            return os.path.expanduser(value)
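The new `${env.VAR:default}` / `${env.VAR+value}` behavior can be exercised on its own. A small self-contained sketch that reuses the regex and branching from the hunk above (the `substitute` helper name is illustrative, not part of the codebase):

```
import os
import re

PATTERN = r"\${env\.([A-Z0-9_]+)(?:([:\+])([^}]*))?}"


def substitute(text: str) -> str:
    def get_env_var(match: re.Match) -> str:
        env_var, operator, value_expr = match.group(1), match.group(2), match.group(3)
        env_value = os.environ.get(env_var)
        if operator == ":":  # ${env.FOO:default} -> FOO's value, or the default if unset/empty
            value = env_value if env_value else value_expr
        elif operator == "+":  # ${env.FOO+value} -> "value" only when FOO is set, else ""
            value = value_expr if env_value else ""
        else:  # ${env.FOO} -> FOO's value, required
            if not env_value:
                raise ValueError(f"{env_var} is not set")
            value = env_value
        return os.path.expanduser(value)

    return re.sub(PATTERN, get_env_var, text)


os.environ["LLAMA_STACK_PORT"] = "8321"
print(substitute("port: ${env.LLAMA_STACK_PORT:5001}"))  # port: 8321
print(substitute("tls: ${env.TLS_CERT+enabled}"))        # tls:  (TLS_CERT unset -> empty string)
```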
@ -98,15 +98,20 @@ case "$env_type" in
    *)
 esac

-set -x

 if [[ "$env_type" == "venv" || "$env_type" == "conda" ]]; then
+    set -x
    $PYTHON_BINARY -m llama_stack.distribution.server.server \
    --yaml-config "$yaml_config" \
    --port "$port" \
    $env_vars \
    $other_args
 elif [[ "$env_type" == "container" ]]; then
+    # Check if container command is available
+    if ! is_command_available $CONTAINER_BINARY; then
+        printf "${RED}Error: ${CONTAINER_BINARY} command not found. Is ${CONTAINER_BINARY} installed and in your PATH?${NC}" >&2
+        exit 1
+    fi

    if is_command_available selinuxenabled &> /dev/null && selinuxenabled; then
        # Disable SELinux labels
        CONTAINER_OPTS="$CONTAINER_OPTS --security-opt label=disable"

@ -136,6 +141,8 @@ elif [[ "$env_type" == "container" ]]; then
        version_tag=$(curl -s $URL | jq -r '.info.version')
    fi

+    set -x
+
    $CONTAINER_BINARY run $CONTAINER_OPTS -it \
        -p $port:$port \
        $env_vars \
@ -1,72 +0,0 @@
#!/bin/bash

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

set -euo pipefail

RED='\033[0;31m'
NC='\033[0m' # No Color

error_handler() {
  echo "Error occurred in script at line: ${1}" >&2
  exit 1
}

trap 'error_handler ${LINENO}' ERR

if [ $# -lt 3 ]; then
  echo "Usage: $0 <venv_path> <yaml_config> <port> <script_args...>"
  exit 1
fi

venv_path="$1"
shift

yaml_config="$1"
shift

port="$1"
shift

# Initialize env_vars as an empty array
env_vars=""
other_args=""
# Process environment variables from --env arguments
while [[ $# -gt 0 ]]; do
  case "$1" in
  --env)
    if [[ -n "$2" ]]; then
      env_vars="$env_vars --env $2"
      shift 2
    else
      echo -e "${RED}Error: --env requires a KEY=VALUE argument${NC}" >&2
      exit 1
    fi
    ;;
  *)
    other_args="$other_args $1"
    shift
    ;;
  esac
done

echo "Using virtual environment: $venv_path"
# Activate virtual environment
if [ ! -d "$venv_path" ]; then
  echo -e "${RED}Error: Virtual environment not found at $venv_path${NC}" >&2
  exit 1
fi

source "$venv_path/bin/activate"

set -x
python -m llama_stack.distribution.server.server \
  --yaml-config "$yaml_config" \
  --port "$port" \
  $env_vars \
  $other_args
@ -33,7 +33,7 @@ class DistributionRegistry(Protocol):


 REGISTER_PREFIX = "distributions:registry"
-KEY_VERSION = "v7"
+KEY_VERSION = "v8"
 KEY_FORMAT = f"{REGISTER_PREFIX}:{KEY_VERSION}::" + "{type}:{identifier}"
@ -212,7 +212,7 @@ def run_evaluation_3():
            benchmark_id=selected_benchmark,
            input_rows=[r],
            scoring_functions=benchmarks[selected_benchmark].scoring_functions,
-            task_config=benchmark_config,
+            benchmark_config=benchmark_config,
        )

        for k in r.keys():
@ -132,9 +132,9 @@ def rag_chat_page():
            },
            toolgroups=[
                dict(
-                    name="builtin::rag",
+                    name="builtin::rag/knowledge_search",
                    args={
-                        "vector_db_ids": [vector_db_id for vector_db_id in selected_vector_dbs],
+                        "vector_db_ids": list(selected_vector_dbs),
                    },
                )
            ],
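After this change the playground passes the RAG tool as a toolgroup dict keyed by the new name. A minimal illustration of just that data structure (the surrounding agent-creation call and the UI selection are omitted; the values are made up):

```
selected_vector_dbs = {"my-docs", "faq"}  # hypothetical selection from the playground UI

# "builtin::rag/knowledge_search" replaces the older "builtin::rag" name, and the
# vector DB ids are passed as a plain list rather than a redundant comprehension.
toolgroups = [
    dict(
        name="builtin::rag/knowledge_search",
        args={"vector_db_ids": list(selected_vector_dbs)},
    )
]
```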
@ -46,7 +46,7 @@ def formulate_run_args(image_type, image_name, config, template_name) -> list:
        conda_env_info = json.loads(subprocess.check_output(["conda", "info", "--envs", "--json"]).decode())
        envs = conda_env_info["envs"]
        for envpath in envs:
-            if envpath.endswith(env_name):
+            if os.path.basename(envpath) == env_name:
                return envpath
        return None
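The switch from `endswith` to `os.path.basename` matters when one environment name is a suffix of another. A quick illustration with made-up paths:

```
import os

envs = ["/opt/conda/envs/my-stack", "/opt/conda/envs/stack"]
env_name = "stack"

# Old check: "/opt/conda/envs/my-stack".endswith("stack") is True, so the wrong env can win.
print([p for p in envs if p.endswith(env_name)])             # both paths match
# New check: only an exact directory-name match is accepted.
print([p for p in envs if os.path.basename(p) == env_name])  # ['/opt/conda/envs/stack']
```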
204  llama_stack/logcat.py  Normal file

@ -0,0 +1,204 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

"""
Category-based logging utility for llama-stack.

This module provides a wrapper over the standard Python logging module that supports
categorized logging with environment variable control.

Usage:
    from llama_stack import logcat
    logcat.info("server", "Starting up...")
    logcat.debug("inference", "Processing request...")

Environment variable:
    LLAMA_STACK_LOGGING: Semicolon-separated list of category=level pairs
    Example: "server=debug;inference=warning"
"""

import datetime
import logging
import os
from typing import Dict

# ANSI color codes for terminal output
COLORS = {
    "RESET": "\033[0m",
    "DEBUG": "\033[36m",  # Cyan
    "INFO": "\033[32m",  # Green
    "WARNING": "\033[33m",  # Yellow
    "ERROR": "\033[31m",  # Red
    "CRITICAL": "\033[35m",  # Magenta
    "DIM": "\033[2m",  # Dimmed text
    "YELLOW_DIM": "\033[2;33m",  # Dimmed yellow
}

# Static list of valid categories representing various parts of the Llama Stack
# server codebase
CATEGORIES = [
    "core",
    "server",
    "router",
    "inference",
    "agents",
    "safety",
    "eval",
    "tools",
    "client",
]

_logger = logging.getLogger("llama_stack")
_logger.propagate = False

_default_level = logging.INFO

# Category-level mapping (can be modified by environment variables)
_category_levels: Dict[str, int] = {}


class TerminalStreamHandler(logging.StreamHandler):
    def __init__(self, stream=None):
        super().__init__(stream)
        self.is_tty = hasattr(self.stream, "isatty") and self.stream.isatty()

    def format(self, record):
        record.is_tty = self.is_tty
        return super().format(record)


class ColoredFormatter(logging.Formatter):
    """Custom formatter with colors and fixed-width level names"""

    def format(self, record):
        levelname = record.levelname
        # Use only time with milliseconds, not date
        timestamp = datetime.datetime.now().strftime("%H:%M:%S.%f")[:-3]  # HH:MM:SS.mmm format

        file_info = f"{record.filename}:{record.lineno}"

        # Get category from extra if available
        category = getattr(record, "category", None)
        msg = record.getMessage()

        if getattr(record, "is_tty", False):
            color = COLORS.get(levelname, COLORS["RESET"])
            if category:
                category_formatted = f"{COLORS['YELLOW_DIM']}{category}{COLORS['RESET']} "
                formatted_msg = (
                    f"{color}{levelname:<7}{COLORS['RESET']} {COLORS['DIM']}{timestamp}{COLORS['RESET']} "
                    f"{file_info:<20} {category_formatted}{msg}"
                )
            else:
                formatted_msg = (
                    f"{color}{levelname:<7}{COLORS['RESET']} {COLORS['DIM']}{timestamp}{COLORS['RESET']}] "
                    f"{file_info:<20} {msg}"
                )
        else:
            if category:
                formatted_msg = f"{levelname:<7} {timestamp} {file_info:<20} [{category}] {msg}"
            else:
                formatted_msg = f"{levelname:<7} {timestamp} {file_info:<20} {msg}"

        return formatted_msg


def init(default_level: int = logging.INFO) -> None:
    global _default_level, _category_levels, _logger

    _default_level = default_level

    _logger.setLevel(logging.DEBUG)
    _logger.handlers = []  # Clear existing handlers

    # Add our custom handler with the colored formatter
    handler = TerminalStreamHandler()
    formatter = ColoredFormatter()
    handler.setFormatter(formatter)
    _logger.addHandler(handler)

    for category in CATEGORIES:
        _category_levels[category] = default_level

    env_config = os.environ.get("LLAMA_STACK_LOGGING", "")
    if env_config:
        for pair in env_config.split(";"):
            if not pair.strip():
                continue

            try:
                category, level = pair.split("=", 1)
                category = category.strip().lower()
                level = level.strip().lower()

                level_value = {
                    "debug": logging.DEBUG,
                    "info": logging.INFO,
                    "warning": logging.WARNING,
                    "warn": logging.WARNING,
                    "error": logging.ERROR,
                    "critical": logging.CRITICAL,
                }.get(level)

                if level_value is None:
                    _logger.warning(f"Unknown log level '{level}' for category '{category}'")
                    continue

                if category == "all":
                    for cat in CATEGORIES:
                        _category_levels[cat] = level_value
                else:
                    if category in CATEGORIES:
                        _category_levels[category] = level_value
                    else:
                        _logger.warning(f"Unknown logging category: {category}")

            except ValueError:
                _logger.warning(f"Invalid logging configuration: {pair}")


def _should_log(level: int, category: str) -> bool:
    category = category.lower()
    if category not in _category_levels:
        return False
    category_level = _category_levels[category]
    return level >= category_level


def _log(level: int, level_name: str, category: str, msg: str, *args, **kwargs) -> None:
    if _should_log(level, category):
        kwargs.setdefault("extra", {})["category"] = category.lower()
        getattr(_logger, level_name)(msg, *args, stacklevel=3, **kwargs)


def debug(category: str, msg: str, *args, **kwargs) -> None:
    _log(logging.DEBUG, "debug", category, msg, *args, **kwargs)


def info(category: str, msg: str, *args, **kwargs) -> None:
    _log(logging.INFO, "info", category, msg, *args, **kwargs)


def warning(category: str, msg: str, *args, **kwargs) -> None:
    _log(logging.WARNING, "warning", category, msg, *args, **kwargs)


def warn(category: str, msg: str, *args, **kwargs) -> None:
    warning(category, msg, *args, **kwargs)


def error(category: str, msg: str, *args, **kwargs) -> None:
    _log(logging.ERROR, "error", category, msg, *args, **kwargs)


def critical(category: str, msg: str, *args, **kwargs) -> None:
    _log(logging.CRITICAL, "critical", category, msg, *args, **kwargs)


def exception(category: str, msg: str, *args, **kwargs) -> None:
    if _should_log(logging.ERROR, category):
        kwargs.setdefault("extra", {})["category"] = category.lower()
        _logger.exception(msg, *args, stacklevel=2, **kwargs)
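Typical use of the new module, following its docstring (the category names come from the `CATEGORIES` list above; the messages are illustrative):

```
import os

# Raise "server" to debug and silence "inference" below warning before init() runs.
os.environ["LLAMA_STACK_LOGGING"] = "server=debug;inference=warning"

from llama_stack import logcat  # noqa: E402

logcat.init()
logcat.info("server", "Starting up...")
logcat.debug("server", "Listening on :8321")       # emitted: server is at debug
logcat.info("inference", "Processing request...")  # suppressed: inference is at warning
```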
@ -11,16 +11,128 @@
 # top-level folder for each specific model found within the models/ directory at
 # the top-level of this source tree.

+import base64
 from enum import Enum
-from typing import Any, Dict, Literal, Optional, Union
+from io import BytesIO
+from typing import Any, Dict, List, Literal, Optional, Union

-# import all for backwards compatibility
-from llama_models.datatypes import *  # noqa: F403
-from pydantic import BaseModel, ConfigDict, Field, field_validator
+from pydantic import BaseModel, ConfigDict, Field, field_serializer, field_validator
 from typing_extensions import Annotated

 from llama_stack.schema_utils import json_schema_type, register_schema

+# The goal is that these set of types are relevant for all Llama models.
+# That isn't the current state yet -- e.g., BuiltinTool is somewhat specific to
+# the llama3 series of models.
+
+
+class Role(Enum):
+    system = "system"
+    user = "user"
+    assistant = "assistant"
+    tool = "tool"
+
+
+class BuiltinTool(Enum):
+    brave_search = "brave_search"
+    wolfram_alpha = "wolfram_alpha"
+    photogen = "photogen"
+    code_interpreter = "code_interpreter"
+
+
+Primitive = Union[str, int, float, bool, None]
+RecursiveType = Union[Primitive, List[Primitive], Dict[str, Primitive]]
+
+
+class ToolCall(BaseModel):
+    call_id: str
+    tool_name: Union[BuiltinTool, str]
+    arguments: Dict[str, RecursiveType]
+
+    @field_validator("tool_name", mode="before")
+    @classmethod
+    def validate_field(cls, v):
+        if isinstance(v, str):
+            try:
+                return BuiltinTool(v)
+            except ValueError:
+                return v
+        return v
+
+
+class ToolPromptFormat(Enum):
+    """Prompt format for calling custom / zero shot tools.
+
+    :cvar json: JSON format for calling tools. It takes the form:
+        {
+            "type": "function",
+            "function" : {
+                "name": "function_name",
+                "description": "function_description",
+                "parameters": {...}
+            }
+        }
+    :cvar function_tag: Function tag format, pseudo-XML. This looks like:
+        <function=function_name>(parameters)</function>
+
+    :cvar python_list: Python list. The output is a valid Python expression that can be
+        evaluated to a list. Each element in the list is a function call. Example:
+        ["function_name(param1, param2)", "function_name(param1, param2)"]
+    """
+
+    json = "json"
+    function_tag = "function_tag"
+    python_list = "python_list"
+
+
+class StopReason(Enum):
+    end_of_turn = "end_of_turn"
+    end_of_message = "end_of_message"
+    out_of_tokens = "out_of_tokens"
+
+
+class RawMediaItem(BaseModel):
+    type: Literal["image"] = "image"
+    data: bytes | BytesIO
+
+    model_config = ConfigDict(arbitrary_types_allowed=True)
+
+    @field_serializer("data")
+    def serialize_data(self, data: Optional[bytes], _info):
+        if data is None:
+            return None
+        return base64.b64encode(data).decode("utf-8")
+
+    @field_validator("data", mode="before")
+    @classmethod
+    def validate_data(cls, v):
+        if isinstance(v, str):
+            return base64.b64decode(v)
+        return v
+
+
+class RawTextItem(BaseModel):
+    type: Literal["text"] = "text"
+    text: str
+
+
+RawContentItem = Annotated[Union[RawTextItem, RawMediaItem], Field(discriminator="type")]
+
+RawContent = str | RawContentItem | List[RawContentItem]
+
+
+class RawMessage(BaseModel):
+    role: Literal["user"] | Literal["system"] | Literal["tool"] | Literal["assistant"]
+    content: RawContent
+
+    # This is for RAG but likely should be absorbed into content
+    context: Optional[RawContent] = None
+
+    # These are for the output message coming from the assistant
+    stop_reason: Optional[StopReason] = None
+    tool_calls: List[ToolCall] = Field(default_factory=list)
+
+
 register_schema(ToolCall)
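A short example of the new `ToolCall` model; note how the `tool_name` validator coerces known built-in names into `BuiltinTool` while leaving custom names as strings (the argument values are made up):

```
from llama_stack.models.llama.datatypes import BuiltinTool, ToolCall

builtin = ToolCall(call_id="1", tool_name="brave_search", arguments={"query": "llama stack"})
custom = ToolCall(call_id="2", tool_name="get_weather", arguments={"city": "Paris", "days": 3})

print(builtin.tool_name)  # BuiltinTool.brave_search (coerced by the validator)
print(custom.tool_name)   # 'get_weather' (left as a plain string)
```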
282  llama_stack/models/llama/llama3/chat_format.py  Normal file

@ -0,0 +1,282 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# top-level folder for each specific model found within the models/ directory at
|
||||||
|
# the top-level of this source tree.
|
||||||
|
|
||||||
|
import io
|
||||||
|
import uuid
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Dict, List, Optional, Tuple
|
||||||
|
|
||||||
|
from PIL import Image as PIL_Image
|
||||||
|
|
||||||
|
from llama_stack.models.llama.datatypes import (
|
||||||
|
BuiltinTool,
|
||||||
|
RawContent,
|
||||||
|
RawMediaItem,
|
||||||
|
RawMessage,
|
||||||
|
RawTextItem,
|
||||||
|
Role,
|
||||||
|
StopReason,
|
||||||
|
ToolCall,
|
||||||
|
ToolPromptFormat,
|
||||||
|
)
|
||||||
|
|
||||||
|
from .tokenizer import Tokenizer
|
||||||
|
from .tool_utils import ToolUtils
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class VisionInput:
|
||||||
|
mask: List[List[int]]
|
||||||
|
images: List[PIL_Image.Image]
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class LLMInput:
|
||||||
|
tokens: List[int]
|
||||||
|
vision: Optional[VisionInput] = None
|
||||||
|
|
||||||
|
|
||||||
|
def role_str(role: Role) -> str:
|
||||||
|
role_strs = {
|
||||||
|
Role.user: "user",
|
||||||
|
Role.system: "system",
|
||||||
|
Role.tool: "ipython", # special
|
||||||
|
Role.assistant: "assistant",
|
||||||
|
}
|
||||||
|
return role_strs[role]
|
||||||
|
|
||||||
|
|
||||||
|
class ChatFormat:
|
||||||
|
possible_headers: Dict[Role, str]
|
||||||
|
|
||||||
|
def __init__(self, tokenizer: Tokenizer):
|
||||||
|
self.tokenizer = tokenizer
|
||||||
|
|
||||||
|
self.possible_headers = {role: f"<|start_header_id|>{role_str(role)}<|end_header_id|>\n\n" for role in Role}
|
||||||
|
self.vision_token = self.tokenizer.special_tokens["<|image|>"]
|
||||||
|
|
||||||
|
def _encode_header(self, role: str) -> List[int]:
|
||||||
|
tokens = []
|
||||||
|
tokens.append(self.tokenizer.special_tokens["<|start_header_id|>"])
|
||||||
|
tokens.extend(self.tokenizer.encode("ipython" if role == "tool" else role, bos=False, eos=False))
|
||||||
|
tokens.append(self.tokenizer.special_tokens["<|end_header_id|>"])
|
||||||
|
tokens.extend(self.tokenizer.encode("\n\n", bos=False, eos=False))
|
||||||
|
return tokens
|
||||||
|
|
||||||
|
def encode_content(self, content: RawContent) -> LLMInput:
|
||||||
|
tokens, images = self._encode_content(content, bos=True)
|
||||||
|
return self._model_input_from_tokens_images(tokens, images)
|
||||||
|
|
||||||
|
def _encode_content(self, content: RawContent, bos: bool = False) -> Tuple[List[int], List[PIL_Image.Image]]:
|
||||||
|
tokens = []
|
||||||
|
images = []
|
||||||
|
|
||||||
|
added_bos = False
|
||||||
|
|
||||||
|
def _process(c):
|
||||||
|
nonlocal added_bos, bos
|
||||||
|
|
||||||
|
if isinstance(c, str) or isinstance(c, RawTextItem):
|
||||||
|
if isinstance(c, RawTextItem):
|
||||||
|
c = c.text
|
||||||
|
tokens.extend(self.tokenizer.encode(c, bos=False if added_bos else bos, eos=False))
|
||||||
|
added_bos = True
|
||||||
|
|
||||||
|
elif isinstance(c, RawMediaItem):
|
||||||
|
bos = False if added_bos else bos
|
||||||
|
if bos:
|
||||||
|
tokens.append(self.tokenizer.special_tokens["<|begin_of_text|>"])
|
||||||
|
added_bos = True
|
||||||
|
tokens.append(self.vision_token)
|
||||||
|
|
||||||
|
bytes_io = io.BytesIO(c.data) if isinstance(c.data, bytes) else c.data
|
||||||
|
image = PIL_Image.open(bytes_io)
|
||||||
|
image = image.convert("RGB")
|
||||||
|
images.append(image)
|
||||||
|
|
||||||
|
if isinstance(content, list):
|
||||||
|
for c in content:
|
||||||
|
_process(c)
|
||||||
|
else:
|
||||||
|
_process(content)
|
||||||
|
|
||||||
|
return tokens, images
|
||||||
|
|
||||||
|
def encode_message(
|
||||||
|
self, message: RawMessage, tool_prompt_format: ToolPromptFormat
|
||||||
|
) -> Tuple[List[int], List[PIL_Image.Image]]:
|
||||||
|
tokens = self._encode_header(message.role)
|
||||||
|
images = []
|
||||||
|
|
||||||
|
def _process_content(c):
|
||||||
|
toks, imgs = self._encode_content(c)
|
||||||
|
tokens.extend(toks)
|
||||||
|
images.extend(imgs)
|
||||||
|
|
||||||
|
if (
|
||||||
|
message.role == "assistant"
|
||||||
|
and len(message.tool_calls) > 0
|
||||||
|
and message.tool_calls[0].tool_name == BuiltinTool.code_interpreter
|
||||||
|
):
|
||||||
|
tokens.append(self.tokenizer.special_tokens["<|python_tag|>"])
|
||||||
|
|
||||||
|
_process_content(message.content)
|
||||||
|
|
||||||
|
if message.role == "user" and message.context is not None:
|
||||||
|
# This is RAG context; why is it here in the chat format? I don't think
|
||||||
|
# this is needed and can be moved upwards
|
||||||
|
_process_content("\n\n")
|
||||||
|
_process_content(message.context)
|
||||||
|
|
||||||
|
if message.role == "assistant":
|
||||||
|
for t in message.tool_calls:
|
||||||
|
content = ToolUtils.encode_tool_call(t, tool_prompt_format)
|
||||||
|
_process_content(content)
|
||||||
|
|
||||||
|
eom = False
|
||||||
|
if message.role == "assistant":
|
||||||
|
eom = message.stop_reason == StopReason.end_of_message
|
||||||
|
|
||||||
|
tokens.append(self.tokenizer.special_tokens["<|eom_id|>" if eom else "<|eot_id|>"])
|
||||||
|
return tokens, images
|
||||||
|
|
||||||
|
def encode_dialog_prompt(
|
||||||
|
self,
|
||||||
|
messages: List[RawMessage],
|
||||||
|
tool_prompt_format: Optional[ToolPromptFormat] = None,
|
||||||
|
) -> LLMInput:
|
||||||
|
tool_prompt_format = tool_prompt_format or ToolPromptFormat.json
|
||||||
|
tokens = []
|
||||||
|
images = []
|
||||||
|
tokens.append(self.tokenizer.special_tokens["<|begin_of_text|>"])
|
||||||
|
for message in messages:
|
||||||
|
toks, imgs = self.encode_message(message, tool_prompt_format)
|
||||||
|
tokens.extend(toks)
|
||||||
|
images.extend(imgs)
|
||||||
|
|
||||||
|
# Add the start of an assistant message for the model to complete.
|
||||||
|
tokens.extend(self._encode_header("assistant"))
|
||||||
|
|
||||||
|
return self._model_input_from_tokens_images(tokens, images)
|
||||||
|
|
||||||
|
# TODO(this should be generic, not only for assistant messages)
|
||||||
|
def decode_assistant_message(self, tokens: List[int], stop_reason: StopReason) -> RawMessage:
|
||||||
|
content = self.tokenizer.decode(tokens)
|
||||||
|
|
||||||
|
return self.decode_assistant_message_from_content(content, stop_reason)
|
||||||
|
|
||||||
|
def decode_assistant_message_from_content(self, content: str, stop_reason: StopReason) -> RawMessage:
|
||||||
|
content = content.strip(" ")
|
||||||
|
header_str = self.possible_headers[Role.assistant]
|
||||||
|
if content.startswith(header_str):
|
||||||
|
content = content[len(header_str) :]
|
||||||
|
|
||||||
|
ipython = content.startswith("<|python_tag|>")
|
||||||
|
if ipython:
|
||||||
|
content = content[len("<|python_tag|>") :]
|
||||||
|
|
||||||
|
if content.endswith("<|eot_id|>"):
|
||||||
|
content = content[: -len("<|eot_id|>")]
|
||||||
|
stop_reason = StopReason.end_of_turn
|
||||||
|
elif content.endswith("<|eom_id|>"):
|
||||||
|
content = content[: -len("<|eom_id|>")]
|
||||||
|
stop_reason = StopReason.end_of_message
|
||||||
|
|
||||||
|
tool_name = None
|
||||||
|
tool_arguments = {}
|
||||||
|
|
||||||
|
custom_tool_info = ToolUtils.maybe_extract_custom_tool_call(content)
|
||||||
|
if custom_tool_info is not None:
|
||||||
|
tool_name, tool_arguments = custom_tool_info
|
||||||
|
# Sometimes when agent has custom tools alongside builin tools
|
||||||
|
# Agent responds for builtin tool calls in the format of the custom tools
|
||||||
|
# This code tries to handle that case
|
||||||
|
if tool_name in BuiltinTool.__members__:
|
||||||
|
tool_name = BuiltinTool[tool_name]
|
||||||
|
tool_arguments = {
|
||||||
|
"query": list(tool_arguments.values())[0],
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
builtin_tool_info = ToolUtils.maybe_extract_builtin_tool_call(content)
|
||||||
|
if builtin_tool_info is not None:
|
||||||
|
tool_name, query = builtin_tool_info
|
||||||
|
tool_arguments = {
|
||||||
|
"query": query,
|
||||||
|
}
|
||||||
|
if tool_name in BuiltinTool.__members__:
|
||||||
|
tool_name = BuiltinTool[tool_name]
|
||||||
|
elif ipython:
|
||||||
|
tool_name = BuiltinTool.code_interpreter
|
||||||
|
tool_arguments = {
|
||||||
|
"code": content,
|
||||||
|
}
|
||||||
|
|
||||||
|
tool_calls = []
|
||||||
|
if tool_name is not None and tool_arguments is not None:
|
||||||
|
call_id = str(uuid.uuid4())
|
||||||
|
tool_calls.append(
|
||||||
|
ToolCall(
|
||||||
|
call_id=call_id,
|
||||||
|
tool_name=tool_name,
|
||||||
|
arguments=tool_arguments,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
content = ""
|
||||||
|
|
||||||
|
return RawMessage(
|
||||||
|
role="assistant",
|
||||||
|
content=content,
|
||||||
|
stop_reason=stop_reason,
|
||||||
|
tool_calls=tool_calls,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _model_input_from_tokens_images(self, tokens: List[int], images: List[PIL_Image.Image]) -> LLMInput:
|
||||||
|
vision_input = None
|
||||||
|
if len(images) > 0:
|
||||||
|
vision_input = VisionInput(
|
||||||
|
mask=create_vision_mask(tokens, self.vision_token),
|
||||||
|
images=images,
|
||||||
|
)
|
||||||
|
|
||||||
|
return LLMInput(
|
||||||
|
tokens=[128256 if token == self.vision_token else token for token in tokens],
|
||||||
|
vision=vision_input,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def create_vision_mask(
|
||||||
|
tokens: List[int],
|
||||||
|
vision_token: int,
|
||||||
|
) -> List[List[int]]:
|
||||||
|
vision_token_locations = [i for i, token in enumerate(tokens) if token == vision_token]
|
||||||
|
if len(vision_token_locations) == 0:
|
||||||
|
return []
|
||||||
|
|
||||||
|
if len(vision_token_locations) == 1:
|
||||||
|
# only one image present, unmask until end of sequence
|
||||||
|
return [[vision_token_locations[0], -1]]
|
||||||
|
vision_masks = [
|
||||||
|
[loc1, loc2] for loc1, loc2 in zip(vision_token_locations[:-1], vision_token_locations[1:], strict=False)
|
||||||
|
]
|
||||||
|
# last image will attend to all subsequent text
|
||||||
|
vision_masks.append([vision_token_locations[-1], len(tokens)])
|
||||||
|
|
||||||
|
# if there are two or more consecutive vision tokens,
|
||||||
|
# they should all attend to all subsequent
|
||||||
|
# text present
|
||||||
|
last_mask_end = vision_masks[-1][1]
|
||||||
|
for vision_mask in vision_masks[::-1]:
|
||||||
|
if vision_mask[0] == vision_mask[1] - 1:
|
||||||
|
vision_mask[1] = last_mask_end
|
||||||
|
last_mask_end = vision_mask[1]
|
||||||
|
return vision_masks
|
||||||
|
|
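A sketch of how the new `ChatFormat` above is driven (it assumes the bundled `tokenizer.model` is present so `Tokenizer.get_instance()` can load it; the dialog content is illustrative):

```
from llama_stack.models.llama.datatypes import RawMessage, ToolPromptFormat
from llama_stack.models.llama.llama3.chat_format import ChatFormat
from llama_stack.models.llama.llama3.tokenizer import Tokenizer

formatter = ChatFormat(Tokenizer.get_instance())
dialog = [
    RawMessage(role="system", content="You are a helpful assistant."),
    RawMessage(role="user", content="What is the capital of France?"),
]
llm_input = formatter.encode_dialog_prompt(dialog, ToolPromptFormat.json)
print(len(llm_input.tokens), llm_input.vision)  # token count, None (no images in this dialog)
```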
@ -14,20 +14,19 @@
 from pathlib import Path
 from typing import List, Optional

-from llama_models.datatypes import (
+from termcolor import colored
+
+from llama_stack.models.llama.datatypes import (
     BuiltinTool,
     RawMessage,
     StopReason,
     ToolCall,
+    ToolDefinition,
     ToolPromptFormat,
 )
-from llama_models.llama3.api.chat_format import ChatFormat
-from llama_models.llama3.api.tokenizer import Tokenizer
-from termcolor import colored
-
-from llama_stack.models.llama.datatypes import ToolDefinition

 from . import template_data
+from .chat_format import ChatFormat
 from .prompt_templates import (
     BuiltinToolGenerator,
     FunctionTagCustomToolGenerator,

@ -35,6 +34,7 @@ from .prompt_templates import (
     SystemDefaultGenerator,
     ToolResponseGenerator,
 )
+from .tokenizer import Tokenizer

 THIS_DIR = Path(__file__).parent
@ -15,11 +15,8 @@ import textwrap
 from datetime import datetime
 from typing import Any, List, Optional

-from llama_models.datatypes import (
-    BuiltinTool,
-)
 from llama_stack.models.llama.datatypes import (
+    BuiltinTool,
     ToolDefinition,
     ToolParamDefinition,
 )

@ -226,10 +223,9 @@ class FunctionTagCustomToolGenerator(PromptTemplateGeneratorBase):
 class PythonListCustomToolGenerator(PromptTemplateGeneratorBase):  # noqa: N801
     DEFAULT_PROMPT = textwrap.dedent(
         """
+        You are a helpful assistant. You have access to functions, but you should only use them if they are required.
         You are an expert in composing functions. You are given a question and a set of possible functions.
-        Based on the question, you will need to make one or more function/tool calls to achieve the purpose.
-        If none of the function can be used, point it out. If the given question lacks the parameters required by the function,
-        also point it out. You should only return the function call in tools call sections.
+        Based on the question, you may or may not need to make one function/tool call to achieve the purpose.

         {{ function_description }}
         """.strip("\n")
@ -11,7 +11,7 @@
 # top-level folder for each specific model found within the models/ directory at
 # the top-level of this source tree.

-from llama_models.datatypes import (
+from llama_stack.models.llama.datatypes import (
     BuiltinTool,
     StopReason,
     ToolCall,
128000  llama_stack/models/llama/llama3/tokenizer.model  Normal file
File diff suppressed because it is too large

214  llama_stack/models/llama/llama3/tokenizer.py  Normal file

@ -0,0 +1,214 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# top-level folder for each specific model found within the models/ directory at
|
||||||
|
# the top-level of this source tree.
|
||||||
|
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
|
||||||
|
|
||||||
|
import os
|
||||||
|
from logging import getLogger
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import (
|
||||||
|
AbstractSet,
|
||||||
|
Collection,
|
||||||
|
Dict,
|
||||||
|
Iterator,
|
||||||
|
List,
|
||||||
|
Literal,
|
||||||
|
Optional,
|
||||||
|
Sequence,
|
||||||
|
Union,
|
||||||
|
cast,
|
||||||
|
)
|
||||||
|
|
||||||
|
import tiktoken
|
||||||
|
from tiktoken.load import load_tiktoken_bpe
|
||||||
|
|
||||||
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
# The tiktoken tokenizer can handle <=400k chars without
|
||||||
|
# pyo3_runtime.PanicException.
|
||||||
|
TIKTOKEN_MAX_ENCODE_CHARS = 400_000
|
||||||
|
|
||||||
|
# https://github.com/openai/tiktoken/issues/195
|
||||||
|
# Here we iterate over subsequences and split if we exceed the limit
|
||||||
|
# of max consecutive non-whitespace or whitespace characters.
|
||||||
|
MAX_NO_WHITESPACES_CHARS = 25_000
|
||||||
|
|
||||||
|
|
||||||
|
_INSTANCE = None
|
||||||
|
|
||||||
|
|
||||||
|
class Tokenizer:
|
||||||
|
"""
|
||||||
|
Tokenizing and encoding/decoding text using the Tiktoken tokenizer.
|
||||||
|
"""
|
||||||
|
|
||||||
|
special_tokens: Dict[str, int]
|
||||||
|
|
||||||
|
num_reserved_special_tokens = 256
|
||||||
|
|
||||||
|
pat_str = r"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+" # noqa: E501
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_instance(cls):
|
||||||
|
global _INSTANCE
|
||||||
|
|
||||||
|
if _INSTANCE is None:
|
||||||
|
_INSTANCE = Tokenizer(os.path.join(os.path.dirname(__file__), "tokenizer.model"))
|
||||||
|
return _INSTANCE
|
||||||
|
|
||||||
|
def __init__(self, model_path: str):
|
||||||
|
"""
|
||||||
|
Initializes the Tokenizer with a Tiktoken model.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_path (str): The path to the Tiktoken model file.
|
||||||
|
"""
|
||||||
|
assert os.path.isfile(model_path), model_path
|
||||||
|
|
||||||
|
mergeable_ranks = load_tiktoken_bpe(model_path)
|
||||||
|
num_base_tokens = len(mergeable_ranks)
|
||||||
|
special_tokens = [
|
||||||
|
"<|begin_of_text|>",
|
||||||
|
"<|end_of_text|>",
|
||||||
|
"<|reserved_special_token_0|>",
|
||||||
|
"<|reserved_special_token_1|>",
|
||||||
|
"<|finetune_right_pad_id|>",
|
||||||
|
"<|step_id|>",
|
||||||
|
"<|start_header_id|>",
|
||||||
|
"<|end_header_id|>",
|
||||||
|
"<|eom_id|>", # end of message
|
||||||
|
"<|eot_id|>", # end of turn
|
||||||
|
"<|python_tag|>",
|
||||||
|
"<|image|>",
|
||||||
|
]
|
||||||
|
reserved_tokens = [
|
||||||
|
f"<|reserved_special_token_{2 + i}|>" for i in range(self.num_reserved_special_tokens - len(special_tokens))
|
||||||
|
]
|
||||||
|
special_tokens = special_tokens + reserved_tokens
|
||||||
|
|
||||||
|
self.special_tokens = {token: num_base_tokens + i for i, token in enumerate(special_tokens)}
|
||||||
|
self.model = tiktoken.Encoding(
|
||||||
|
name=Path(model_path).name,
|
||||||
|
pat_str=self.pat_str,
|
||||||
|
mergeable_ranks=mergeable_ranks,
|
||||||
|
special_tokens=self.special_tokens,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.n_words: int = num_base_tokens + len(special_tokens)
|
||||||
|
# BOS / EOS token IDs
|
||||||
|
self.bos_id: int = self.special_tokens["<|begin_of_text|>"]
|
||||||
|
self.eos_id: int = self.special_tokens["<|end_of_text|>"]
|
||||||
|
self.eot_id: int = self.special_tokens["<|eot_id|>"]
|
||||||
|
self.eom_id: int = self.special_tokens["<|eom_id|>"]
|
||||||
|
self.python_tag_id = self.special_tokens["<|python_tag|>"]
|
||||||
|
self.pad_id: int = self.special_tokens["<|finetune_right_pad_id|>"]
|
||||||
|
self.stop_tokens = [
|
||||||
|
self.eos_id,
|
||||||
|
self.special_tokens["<|eom_id|>"],
|
||||||
|
self.special_tokens["<|eot_id|>"],
|
||||||
|
]
|
||||||
|
|
||||||
|
def encode(
|
||||||
|
self,
|
||||||
|
s: str,
|
||||||
|
*,
|
||||||
|
bos: bool,
|
||||||
|
eos: bool,
|
||||||
|
allowed_special: Optional[Union[Literal["all"], AbstractSet[str]]] = None,
|
||||||
|
disallowed_special: Union[Literal["all"], Collection[str]] = (),
|
||||||
|
) -> List[int]:
|
||||||
|
"""
|
||||||
|
Encodes a string into a list of token IDs.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
s (str): The input string to be encoded.
|
||||||
|
bos (bool): Whether to prepend the beginning-of-sequence token.
|
||||||
|
eos (bool): Whether to append the end-of-sequence token.
|
||||||
|
allowed_special ("all"|set[str]): allowed special tokens in string
|
||||||
|
disallowed_special ("all"|set[str]): special tokens that raise an error when in string
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
list[int]: A list of token IDs.
|
||||||
|
|
||||||
|
By default, setting disallowed_special=() encodes a string by ignoring
|
||||||
|
special tokens. Specifically:
|
||||||
|
- Setting `disallowed_special` to () will cause all text corresponding
|
||||||
|
to special tokens to be encoded as natural text (insteading of raising
|
||||||
|
an error).
|
||||||
|
- Setting `allowed_special` to "all" will treat all text corresponding
|
||||||
|
to special tokens to be encoded as special tokens.
|
||||||
|
"""
|
||||||
|
if allowed_special is None:
|
||||||
|
allowed_special = set()
|
||||||
|
assert type(s) is str
|
||||||
|
|
||||||
|
substrs = (
|
||||||
|
substr
|
||||||
|
for i in range(0, len(s), TIKTOKEN_MAX_ENCODE_CHARS)
|
||||||
|
for substr in self._split_whitespaces_or_nonwhitespaces(
|
||||||
|
s[i : i + TIKTOKEN_MAX_ENCODE_CHARS], MAX_NO_WHITESPACES_CHARS
|
||||||
|
)
|
||||||
|
)
|
||||||
|
t: List[int] = []
|
||||||
|
for substr in substrs:
|
||||||
|
t.extend(
|
||||||
|
self.model.encode(
|
||||||
|
substr,
|
||||||
|
allowed_special=allowed_special,
|
||||||
|
disallowed_special=disallowed_special,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
if bos:
|
||||||
|
t.insert(0, self.bos_id)
|
||||||
|
if eos:
|
||||||
|
t.append(self.eos_id)
|
||||||
|
return t
|
||||||
|
|
||||||
|
def decode(self, t: Sequence[int]) -> str:
|
||||||
|
"""
|
||||||
|
Decodes a list of token IDs into a string.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
t (List[int]): The list of token IDs to be decoded.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: The decoded string.
|
||||||
|
"""
|
||||||
|
# Typecast is safe here. Tiktoken doesn't do anything list-related with the sequence.
|
||||||
|
return self.model.decode(cast(List[int], t))
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _split_whitespaces_or_nonwhitespaces(s: str, max_consecutive_slice_len: int) -> Iterator[str]:
|
||||||
|
"""
|
||||||
|
Splits the string `s` so that each substring contains no more than `max_consecutive_slice_len`
|
||||||
|
consecutive whitespaces or consecutive non-whitespaces.
|
||||||
|
"""
|
||||||
|
current_slice_len = 0
|
||||||
|
current_slice_is_space = s[0].isspace() if len(s) > 0 else False
|
||||||
|
slice_start = 0
|
||||||
|
|
||||||
|
for i in range(len(s)):
|
||||||
|
is_now_space = s[i].isspace()
|
||||||
|
|
||||||
|
if current_slice_is_space ^ is_now_space:
|
||||||
|
current_slice_len = 1
|
||||||
|
current_slice_is_space = is_now_space
|
||||||
|
else:
|
||||||
|
current_slice_len += 1
|
||||||
|
if current_slice_len > max_consecutive_slice_len:
|
||||||
|
yield s[slice_start:i]
|
||||||
|
slice_start = i
|
||||||
|
current_slice_len = 1
|
||||||
|
yield s[slice_start:]
|
||||||
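Round-tripping text through the vendored tokenizer above, per its `encode`/`decode` signatures (a sketch; it assumes `tokenizer.model` ships next to the module, which `get_instance()` relies on):

```
from llama_stack.models.llama.llama3.tokenizer import Tokenizer

tok = Tokenizer.get_instance()
ids = tok.encode("Hello, Llama!", bos=True, eos=False)
print(ids[0] == tok.bos_id)  # True: bos=True prepends <|begin_of_text|>
print(tok.decode(ids))       # '<|begin_of_text|>Hello, Llama!'
```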
199  llama_stack/models/llama/llama3/tool_utils.py  Normal file

@ -0,0 +1,199 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# top-level folder for each specific model found within the models/ directory at
|
||||||
|
# the top-level of this source tree.
|
||||||
|
import ast
|
||||||
|
import json
|
||||||
|
import re
|
||||||
|
from typing import Optional, Tuple
|
||||||
|
|
||||||
|
from llama_stack.models.llama.datatypes import BuiltinTool, RecursiveType, ToolCall, ToolPromptFormat
|
||||||
|
|
||||||
|
BUILTIN_TOOL_PATTERN = r'\b(?P<tool_name>\w+)\.call\(query="(?P<query>[^"]*)"\)'
|
||||||
|
CUSTOM_TOOL_CALL_PATTERN = re.compile(r"<function=(?P<function_name>[^}]+)>(?P<args>{.*?})")
|
||||||
|
|
||||||
|
|
||||||
|
def is_json(s):
|
||||||
|
try:
|
||||||
|
parsed = json.loads(s)
|
||||||
|
# Return True for valid objects and not for ints, strings, etc
|
||||||
|
return isinstance(parsed, dict)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
def is_valid_python_list(input_string):
|
||||||
|
"""Check if the input string is a valid Python list of function calls"""
|
||||||
|
try:
|
||||||
|
# Try to parse the string
|
||||||
|
tree = ast.parse(input_string)
|
||||||
|
|
||||||
|
# Check if it's a single expression
|
||||||
|
if len(tree.body) != 1 or not isinstance(tree.body[0], ast.Expr):
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Check if the expression is a list
|
||||||
|
expr = tree.body[0].value
|
||||||
|
if not isinstance(expr, ast.List):
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Check if the list is empty
|
||||||
|
if len(expr.elts) == 0:
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Check if all elements in the list are function calls
|
||||||
|
for element in expr.elts:
|
||||||
|
if not isinstance(element, ast.Call):
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Check if the function call has a valid name
|
||||||
|
if not isinstance(element.func, ast.Name):
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Check if all arguments are keyword arguments
|
||||||
|
if element.args or not all(isinstance(arg, ast.keyword) for arg in element.keywords):
|
||||||
|
return False
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
except SyntaxError:
|
||||||
|
# If parsing fails, it's not a valid Python expression
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def parse_python_list_for_function_calls(input_string):
|
||||||
|
"""
|
||||||
|
Parse a Python list of function calls and
|
||||||
|
return a list of tuples containing the function name and arguments
|
||||||
|
"""
|
||||||
|
# Parse the string into an AST
|
||||||
|
tree = ast.parse(input_string)
|
||||||
|
|
||||||
|
# Ensure the input is a list
|
||||||
|
if not isinstance(tree.body[0], ast.Expr) or not isinstance(tree.body[0].value, ast.List):
|
||||||
|
raise ValueError("Input must be a list of function calls")
|
||||||
|
|
||||||
|
result = []
|
||||||
|
|
||||||
|
# Iterate through each function call in the list
|
||||||
|
for node in tree.body[0].value.elts:
|
||||||
|
if isinstance(node, ast.Call):
|
||||||
|
function_name = node.func.id
|
||||||
|
function_args = {}
|
||||||
|
|
||||||
|
# Extract keyword arguments
|
||||||
|
for keyword in node.keywords:
|
||||||
|
function_args[keyword.arg] = ast.literal_eval(keyword.value)
|
||||||
|
|
||||||
|
result.append((function_name, function_args))
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
class ToolUtils:
|
||||||
|
@staticmethod
|
||||||
|
def is_builtin_tool_call(message_body: str) -> bool:
|
||||||
|
match = re.search(ToolUtils.BUILTIN_TOOL_PATTERN, message_body)
|
||||||
|
return match is not None
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def maybe_extract_builtin_tool_call(message_body: str) -> Optional[Tuple[str, str]]:
|
        # Find the first match in the text
        match = re.search(BUILTIN_TOOL_PATTERN, message_body)

        # Check if a match is found and return it
        if match:
            tool_name = match.group("tool_name")
            query = match.group("query")
            return tool_name, query
        else:
            return None

    @staticmethod
    def maybe_extract_custom_tool_call(message_body: str) -> Optional[Tuple[str, str]]:
        # NOTE: Custom function tool calls are still experimental
        # Sometimes, the response is of the form
        # {"type": "function", "name": "function_name", "parameters": {...}
        # and sometimes
        # <function=function_name>(parameters)</function>

        # Find the first match in the text
        match = re.search(CUSTOM_TOOL_CALL_PATTERN, message_body)
        if match:
            tool_name = match.group("function_name")
            query = match.group("args")
            try:
                return tool_name, json.loads(query.replace("'", '"'))
            except Exception as e:
                print("Exception while parsing json query for custom tool call", query, e)
                return None
        elif is_json(message_body):
            response = json.loads(message_body)
            if ("type" in response and response["type"] == "function") or ("name" in response):
                function_name = response["name"]
                args = response["parameters"]
                return function_name, args
            else:
                return None
        elif is_valid_python_list(message_body):
            res = parse_python_list_for_function_calls(message_body)
            # FIXME: Enable multiple tool calls
            return res[0]
        else:
            return None

    @staticmethod
    def encode_tool_call(t: ToolCall, tool_prompt_format: ToolPromptFormat) -> str:
        if t.tool_name == BuiltinTool.brave_search:
            q = t.arguments["query"]
            return f'brave_search.call(query="{q}")'
        elif t.tool_name == BuiltinTool.wolfram_alpha:
            q = t.arguments["query"]
            return f'wolfram_alpha.call(query="{q}")'
        elif t.tool_name == BuiltinTool.photogen:
            q = t.arguments["query"]
            return f'photogen.call(query="{q}")'
        elif t.tool_name == BuiltinTool.code_interpreter:
            return t.arguments["code"]
        else:
            fname = t.tool_name

            if tool_prompt_format == ToolPromptFormat.json:
                return json.dumps(
                    {
                        "type": "function",
                        "name": fname,
                        "parameters": t.arguments,
                    }
                )
            elif tool_prompt_format == ToolPromptFormat.function_tag:
                args = json.dumps(t.arguments)
                return f"<function={fname}>{args}</function>"

            elif tool_prompt_format == ToolPromptFormat.python_list:

                def format_value(value: RecursiveType) -> str:
                    if isinstance(value, str):
                        return f'"{value}"'
                    elif isinstance(value, (int, float, bool)) or value is None:
                        return str(value)
                    elif isinstance(value, list):
                        return f"[{', '.join(format_value(v) for v in value)}]"
                    elif isinstance(value, dict):
                        return f"{{{', '.join(f'{k}={format_value(v)}' for k, v in value.items())}}}"
                    else:
                        raise ValueError(f"Unsupported type: {type(value)}")

                args_str = ", ".join(f"{k}={format_value(v)}" for k, v in t.arguments.items())
                return f"[{fname}({args_str})]"
            else:
                raise ValueError(f"Unsupported tool prompt format: {tool_prompt_format}")
358
llama_stack/models/llama/llama3_1/prompt_format.md
Normal file

@@ -0,0 +1,358 @@
# Llama 3.1 - Prompt Formats

## Tokens

Here is a list of special tokens that are supported by Llama 3.1:

- `<|begin_of_text|>`: Specifies the start of the prompt
- `<|end_of_text|>`: Model will cease to generate more tokens. This token is generated only by the base models.
- `<|finetune_right_pad_id|>`: This token is used for padding text sequences to the same length in a batch.
- `<|start_header_id|>` and `<|end_header_id|>`: These tokens enclose the role for a particular message. The possible roles are: [system, user, assistant and ipython]
- `<|eom_id|>`: End of message. A message represents a possible stopping point for execution where the model can inform the executor that a tool call needs to be made. This is used for multi-step interactions between the model and any available tools. This token is emitted by the model when the Environment: ipython instruction is used in the system prompt, or if the model calls for a built-in tool.
- `<|eot_id|>`: End of turn. Represents when the model has determined that it has finished interacting with the user message that initiated its response. This is used in two scenarios:
  - at the end of a direct interaction between the model and the user
  - at the end of multiple interactions between the model and any available tools

  This token signals to the executor that the model has finished generating a response.
- `<|python_tag|>`: A special tag used in the model's response to signify a tool call.

## Roles

There are 4 different roles that are supported by Llama 3.1:

- `system`: Sets the context in which to interact with the AI model. It typically includes rules, guidelines, or necessary information that helps the model respond effectively.
- `user`: Represents the human interacting with the model. It includes the inputs, commands, and questions to the model.
- `ipython`: A new role introduced in Llama 3.1. Semantically, this role means "tool". This role is used to mark messages with the output of a tool call when sent back to the model from the executor.
- `assistant`: Represents the response generated by the AI model based on the context provided in the `system`, `ipython` and `user` prompts.
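Putting the tokens and roles together, a chat prompt is just a flat string of header-delimited messages. Below is a minimal sketch in plain Python (no Llama Stack APIs assumed; the helper names are illustrative only) of how such a prompt string could be assembled:

```python
# Minimal sketch: assemble a Llama 3.1 chat prompt string from (role, content) pairs.
# encode_message and build_prompt are hypothetical helpers, not part of any library.
def encode_message(role: str, content: str) -> str:
    return f"<|start_header_id|>{role}<|end_header_id|>\n\n{content}<|eot_id|>"

def build_prompt(messages: list[tuple[str, str]]) -> str:
    prompt = "<|begin_of_text|>"
    for role, content in messages:
        prompt += encode_message(role, content)
    # Cue the model to generate the assistant turn next
    prompt += "<|start_header_id|>assistant<|end_header_id|>\n\n"
    return prompt

print(build_prompt([
    ("system", "You are a helpful assistant"),
    ("user", "Answer who are you in the form of jeopardy?"),
]))
```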
## Llama 3.1 Base Model

Text completion for Llama 3.1 base model uses this format.

##### Input Prompt Format
```
<|begin_of_text|>Color of sky is blue but sometimes can also be
```

##### Model Response Format
```
red, orange, yellow, green, purple, pink, brown, gray, black, white, and even rainbow colors. The color of the sky can change due to various reasons such as time of day, weather conditions, pollution, and atmospheric phenomena.
The color of the sky is primarily blue because of a phenomenon called
```

##### Notes

Note the special `<|begin_of_text|>` tag at the start of the prompt.

## Llama 3.1 Instruct Model

## User and assistant conversation

Here is a regular multi-turn user assistant conversation and how it is formatted.

##### Input Prompt Format
```
<|begin_of_text|><|start_header_id|>system<|end_header_id|>

You are a helpful assistant<|eot_id|><|start_header_id|>user<|end_header_id|>

Answer who are you in the form of jeopardy?<|eot_id|><|start_header_id|>assistant<|end_header_id|>


```

##### Model Response Format
```
Here's my response

"What is a helpful assistant?"<|eot_id|>
```

## Tool Calling Formats

The three built-in tools (brave_search, wolfram_alpha, and code interpreter) can be turned on using the system prompt:
- Brave Search: Tool call to perform web searches.
- Wolfram Alpha: Tool call to perform complex mathematical calculations.
- Code Interpreter: Enables the model to output python code.

## Builtin Tool Calling

Here is an example of a conversation using brave search.

##### Input Prompt Format
```
<|begin_of_text|><|start_header_id|>system<|end_header_id|>

Environment: ipython
Tools: brave_search, wolfram_alpha
Cutting Knowledge Date: December 2023
Today Date: 21 September 2024

You are a helpful assistant.
<|eot_id|><|start_header_id|>user<|end_header_id|>

Search the web for the latest price of 1oz gold?<|eot_id|><|start_header_id|>assistant<|end_header_id|>


```

##### Model Response Format
```
<|python_tag|>brave_search.call(query="latest price of 1oz gold")<|eom_id|>
```

##### Notes

- Just including Environment: ipython turns on code interpreter; therefore, you don't need to specify code interpretation on the Tools: line. The model can generate python code which is interpreted by the executor, with the result provided back to the model.
- The message body of the assistant response starts with a special tag <|python_tag|>
- As alluded to above, in such an environment, the model can generate <|eom_id|> instead of just the standard <|eot_id|>. The latter indicates the turn is finished, while the former indicates continued multi-step reasoning. That is, the model is expecting a continuation message with the output of the tool call.
- The model tool call response is of the form `tool.call(query="...")` where tool is `brave_search` or `wolfram_alpha` (a small parsing sketch follows below).
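As a concrete illustration of what an executor might do with such a response, here is a hedged sketch using only standard-library Python; the regex and function names are illustrative, not the Llama Stack implementation:

```python
import re

# Illustrative pattern for built-in tool calls like: brave_search.call(query="...")
BUILTIN_TOOL_PATTERN = r'\b(?P<tool_name>\w+)\.call\(query="(?P<query>[^"]*)"\)'

def parse_builtin_tool_call(message_body: str):
    """Return (tool_name, query) if the response looks like a built-in tool call, else None."""
    match = re.search(BUILTIN_TOOL_PATTERN, message_body)
    if match:
        return match.group("tool_name"), match.group("query")
    return None

raw = '<|python_tag|>brave_search.call(query="latest price of 1oz gold")<|eom_id|>'
print(parse_builtin_tool_call(raw))
# ('brave_search', 'latest price of 1oz gold')
```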
## Builtin Code Interpreter

Here is an actual example of the model responding with code.

##### Input Prompt Format
```
<|begin_of_text|><|start_header_id|>system<|end_header_id|>

Environment: ipython<|eot_id|><|start_header_id|>user<|end_header_id|>

Write code to check if number is prime, use that to see if the number 7 is prime<|eot_id|><|start_header_id|>assistant<|end_header_id|>


```

##### Model Response Format
```
<|python_tag|>def is_prime(n):
    if n <= 1
        return False
    for i in range(2, int(n**0.5) + 1):
        if n % i == 0:
            return False
    return True

print(is_prime(7)) # Output: True<|eom_id|>
```

##### Notes

- Model starts with <|python_tag|> and continues writing python code that it needs to be executed
- No explicit mention of code_interpreter in system prompt. `Environment: ipython` implicitly enables it.
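To make the mechanics concrete, here is a hedged sketch (plain Python, illustrative names only) of how an executor could lift the generated code out of such a response before running it in a sandbox:

```python
# Illustrative only: strip the special tags around a code-interpreter response.
def extract_python_code(response: str) -> str | None:
    start_tag, end_tag = "<|python_tag|>", "<|eom_id|>"
    if start_tag not in response:
        return None
    code = response.split(start_tag, 1)[1]
    if end_tag in code:
        code = code.split(end_tag, 1)[0]
    return code.strip()

raw = '<|python_tag|>print("hello from the code interpreter")<|eom_id|>'
print(extract_python_code(raw))  # print("hello from the code interpreter")
```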
## Built-in tools full interaction

Here is a full interaction with the built-in tools including the tool response and the final assistant response.

##### Input Prompt Format
```
<|begin_of_text|><|start_header_id|>system<|end_header_id|>

Environment: ipython
Tools: brave_search, wolfram_alpha
<|eot_id|><|start_header_id|>user<|end_header_id|>

What is the 100th decimal of pi?<|eot_id|><|start_header_id|>assistant<|end_header_id|>

<|python_tag|>wolfram_alpha.call(query="100th decimal of pi")<|eom_id|><|start_header_id|>ipython<|end_header_id|>

{
    "queryresult": {
        "success": true,
        "inputstring": "100th decimal of pi",
        "pods": [
            {
                "title": "Input interpretation",
                "subpods": [
                    {
                        "title": "",
                        "plaintext": "100th digit | π"
                    }
                ]
            },
            {
                "title": "Nearby digits",
                "subpods": [
                    {
                        "title": "",
                        "plaintext": "...86208998628034825342117067982148086513282306647093..."
                    }
                ]
            },
            {
                "title": "Result",
                "primary": true,
                "subpods": [
                    {
                        "title": "",
                        "plaintext": "7"
                    }
                ]
            }
        ]
    }
}
<|eot_id|><|start_header_id|>assistant<|end_header_id|>


```

##### Model Response Format
```
The 100th decimal of pi is 7.<|eot_id|>
```

##### Notes

- Note the `<|python_tag|>` in the assistant response.
- Role is `ipython` for the wolfram alpha response that is passed back to the model.
- Final message from assistant has <|eot_id|> tag.
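The shape of the full loop is easier to see in code. The following is a minimal, hedged sketch of the driver an executor might implement; the `generate` and `run_tool` callables and the prompt-building details are assumptions for illustration, not real APIs:

```python
# Illustrative multi-step tool loop: keep generating until the model emits <|eot_id|>.
def run_turn(prompt: str, generate, run_tool) -> str:
    while True:
        response = generate(prompt)  # model call; raw text ending in <|eom_id|> or <|eot_id|>
        if response.endswith("<|eot_id|>"):
            # Turn is finished; return the assistant's final answer.
            return response.removesuffix("<|eot_id|>").strip()
        # <|eom_id|>: the model wants a tool result before continuing.
        tool_output = run_tool(response)  # e.g. call brave_search / wolfram_alpha
        prompt += response
        prompt += f"<|start_header_id|>ipython<|end_header_id|>\n\n{tool_output}<|eot_id|>"
        prompt += "<|start_header_id|>assistant<|end_header_id|>\n\n"
```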
## Zero shot tool calling

## JSON based tool calling

Llama models can now output custom tool calls from a single message to allow easier tool calling.
The following prompts provide an example of how custom tools can be called from the output of the model.
It's important to note that the model itself does not execute the calls; it provides structured output to facilitate calling by an executor.

##### Input Prompt Format
```
<|begin_of_text|><|start_header_id|>system<|end_header_id|>

Environment: ipython

Cutting Knowledge Date: December 2023
Today Date: 21 September 2024

You are a helpful assistant.
<|eot_id|><|start_header_id|>user<|end_header_id|>

Answer the user's question by making use of the following functions if needed.
If none of the function can be used, please say so.
Here is a list of functions in JSON format:
{
    "type": "function",
    "function": {
        "name": "trending_songs",
        "description": "Returns the trending songs on a Music site",
        "parameters": {
            "type": "object",
            "properties": [
                {
                    "n": {
                        "type": "object",
                        "description": "The number of songs to return"
                    }
                },
                {
                    "genre": {
                        "type": "object",
                        "description": "The genre of the songs to return"
                    }
                }
            ],
            "required": ["n"]
        }
    }
}

Return function calls in JSON format.<|eot_id|><|start_header_id|>user<|end_header_id|>

Use tools to get latest trending songs<|eot_id|><|start_header_id|>assistant<|end_header_id|>


```

##### Model Response Format
```
<|python_tag|>{
    "type": "function",
    "name": "trending_songs",
    "parameters": {
        "n": "10",
        "genre": "all"
    }
}<|eom_id|>
```

##### Notes

- JSON format for providing tools needs name, description and parameters
- Model responds with `<|python_tag|>` and `<|eom_id|>` as `Environment: ipython` was in the system prompt
- Instructions for tools added as a user message
- Only single tool calls are supported as of now (a parsing sketch for this JSON form follows below)
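Parsing this JSON form on the executor side is straightforward. A minimal sketch, assuming only the format shown above (the helper name is illustrative):

```python
import json

def parse_json_tool_call(response: str):
    """Return (name, parameters) from a <|python_tag|>{...}<|eom_id|> style response, or None."""
    body = response.removeprefix("<|python_tag|>").removesuffix("<|eom_id|>").strip()
    try:
        payload = json.loads(body)
    except json.JSONDecodeError:
        return None
    if payload.get("type") == "function" or "name" in payload:
        return payload["name"], payload.get("parameters", {})
    return None

raw = '<|python_tag|>{"type": "function", "name": "trending_songs", "parameters": {"n": "10", "genre": "all"}}<|eom_id|>'
print(parse_json_tool_call(raw))  # ('trending_songs', {'n': '10', 'genre': 'all'})
```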
## Example of a user defined tool calling

## `<function>` based tool calling

Here is an example of how you could also write custom instructions for the model to do zero shot tool calling.
In this example, we define a custom tool calling format using the `<function>` tag.

##### Input Prompt Format
```
<|begin_of_text|><|start_header_id|>system<|end_header_id|>

Environment: ipython

Cutting Knowledge Date: December 2023
Today Date: 21 September 2024

You are a helpful assistant.
<|eot_id|><|start_header_id|>user<|end_header_id|>

You have access to the following functions:

Use the function 'trending_songs' to 'Returns the trending songs on a Music site':
{"name": "trending_songs", "description": "Returns the trending songs on a Music site", "parameters": {"genre": {"description": "The genre of the songs to return", "param_type": "str", "required": false}, "n": {"description": "The number of songs to return", "param_type": "int", "required": true}}}

Think very carefully before calling functions.
If you choose to call a function ONLY reply in the following format with no prefix or suffix:

<function=example_function_name>{"example_name": "example_value"}</function>

Reminder:
- If looking for real time information use relevant functions before falling back to brave_search
- Function calls MUST follow the specified format, start with <function= and end with </function>
- Required parameters MUST be specified
- Only call one function at a time
- Put the entire function call reply on one line<|eot_id|><|start_header_id|>user<|end_header_id|>

Use tools to get latest trending songs<|eot_id|><|start_header_id|>assistant<|end_header_id|>


```

##### Model Response Format
```
<function=trending_songs>{"n": 10}</function><|eot_id|>
```

##### Notes

- In this case, model does NOT respond with `<|python_tag|>` and ends with `<|eot_id|>`
- Instructions for tools added as a user message
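A matching parser on the executor side only needs a small regex. The sketch below is illustrative; the pattern and names are assumptions, not the library's implementation:

```python
import json
import re

# Illustrative pattern for <function=name>{"arg": "value"}</function> style calls.
FUNCTION_TAG_PATTERN = r"<function=(?P<name>[^>]+)>(?P<args>.*?)</function>"

def parse_function_tag_call(response: str):
    match = re.search(FUNCTION_TAG_PATTERN, response, re.DOTALL)
    if not match:
        return None
    return match.group("name"), json.loads(match.group("args"))

print(parse_function_tag_call('<function=trending_songs>{"n": 10}</function><|eot_id|>'))
# ('trending_songs', {'n': 10})
```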
Thank You!
@@ -14,7 +14,7 @@
 import textwrap
 from typing import List

-from llama_models.datatypes import (
+from llama_stack.models.llama.datatypes import (
     BuiltinTool,
     RawMessage,
     StopReason,

@@ -13,7 +13,7 @@
 import json
 import textwrap

-from llama_models.datatypes import (
+from llama_stack.models.llama.datatypes import (
     RawMessage,
     StopReason,
     ToolCall,

@@ -14,7 +14,7 @@
 import textwrap
 from pathlib import Path

-from llama_models.datatypes import (
+from llama_stack.models.llama.datatypes import (
     RawMediaItem,
     RawMessage,
     RawTextItem,
286
llama_stack/models/llama/llama3_2/text_prompt_format.md
Normal file

@@ -0,0 +1,286 @@
## User and assistant conversation

Here is a regular multi-turn user assistant conversation and how it is formatted.

##### Input Prompt Format
```
<|begin_of_text|><|start_header_id|>system<|end_header_id|>

You are a helpful assistant<|eot_id|><|start_header_id|>user<|end_header_id|>

Who are you?<|eot_id|><|start_header_id|>assistant<|end_header_id|>


```

##### Model Response Format
```
I'm an AI assistant, which means I'm a computer program designed to simulate conversation and answer questions to the best of my ability. I'm here to help you with any questions or tasks you may have, and I'll do my best to provide you with accurate and helpful information.

I don't have a personal name, but I'm often referred to as a "virtual assistant" or a "chatbot." I'm a machine learning model, which means I've been trained on a large dataset of text and can generate responses based on patterns and context.

I can help with a wide range of topics, from general knowledge and trivia to more specialized subjects like science, history, and technology. I can also assist with tasks like language translation, text summarization, and even generating creative content like stories or poetry.

So, what can I help you with today?<|eot_id|>
```

##### Notes

This format is unchanged from Llama3.1

## Zero shot function calling

For Llama3.2 1B and 3B instruct models, we are introducing a new format for zero shot function calling.
This new format is designed to be more flexible and powerful than the previous format.
All available functions can be provided in the system message. A key difference is in the format of how the assistant responds with function calls.
It is pythonic in the form of `[func1(params_name=params_value, params_name2=params_value2...), func2(params)]` instead of the `json` or `<function>` tag that were defined in Llama3.1.
Here is an example of the same; a short parsing sketch follows the notes below.

##### Input Prompt Format
```
<|begin_of_text|><|start_header_id|>system<|end_header_id|>

You are an expert in composing functions. You are given a question and a set of possible functions.
Based on the question, you will need to make one or more function/tool calls to achieve the purpose.
If none of the function can be used, point it out. If the given question lacks the parameters required by the function,
also point it out. You should only return the function call in tools call sections.

If you decide to invoke any of the function(s), you MUST put it in the format of [func_name1(params_name1=params_value1, params_name2=params_value2...), func_name2(params)]
You SHOULD NOT include any other text in the response.

Here is a list of functions in JSON format that you can invoke.

[
    {
        "name": "get_weather",
        "description": "Get weather info for places",
        "parameters": {
            "type": "dict",
            "required": [
                "city"
            ],
            "properties": {
                "city": {
                    "type": "string",
                    "description": "The name of the city to get the weather for"
                },
                "metric": {
                    "type": "string",
                    "description": "The metric for weather. Options are: celsius, fahrenheit",
                    "default": "celsius"
                }
            }
        }
    }
]<|eot_id|><|start_header_id|>user<|end_header_id|>

What is the weather in SF and Seattle?<|eot_id|><|start_header_id|>assistant<|end_header_id|>


```

##### Model Response Format
```
[get_weather(city='San Francisco', metric='celsius'), get_weather(city='Seattle', metric='celsius')]<|eot_id|>
```

##### Notes

- The output supports multiple tool calls natively
- JSON format for defining the functions in the system prompt is similar to Llama3.1
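Because the calls are valid Python expressions, they can be parsed safely with the standard `ast` module rather than `eval`. A minimal sketch with illustrative names (this is not the Llama Stack parser):

```python
import ast

def parse_pythonic_tool_calls(response: str):
    """Parse '[func(a=1), func2(b="x")]' into a list of (name, kwargs) tuples."""
    body = response.removesuffix("<|eot_id|>").strip()
    tree = ast.parse(body, mode="eval")
    calls = []
    for node in tree.body.elts:  # the top-level expression is a list of Call nodes
        name = node.func.id
        kwargs = {kw.arg: ast.literal_eval(kw.value) for kw in node.keywords}
        calls.append((name, kwargs))
    return calls

raw = "[get_weather(city='San Francisco', metric='celsius'), get_weather(city='Seattle', metric='celsius')]<|eot_id|>"
print(parse_pythonic_tool_calls(raw))
# [('get_weather', {'city': 'San Francisco', 'metric': 'celsius'}), ('get_weather', {'city': 'Seattle', 'metric': 'celsius'})]
```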
## Zero shot function calling with user message

While the default is to provide all function calls in a system message, in Llama3.2 text models you can also provide information for all the available tools in a user message.

##### Input Prompt Format
```
<|begin_of_text|><|start_header_id|>user<|end_header_id|>

Questions: Can you retrieve the details for the user with the ID 7890, who has black as their special request?
Here is a list of functions in JSON format that you can invoke:
[
    {
        "name": "get_user_info",
        "description": "Retrieve details for a specific user by their unique identifier. Note that the provided function is in Python 3 syntax.",
        "parameters": {
            "type": "dict",
            "required": [
                "user_id"
            ],
            "properties": {
                "user_id": {
                    "type": "integer",
                    "description": "The unique identifier of the user. It is used to fetch the specific user details from the database."
                },
                "special": {
                    "type": "string",
                    "description": "Any special information or parameters that need to be considered while fetching user details.",
                    "default": "none"
                }
            }
        }
    }
]

Should you decide to return the function call(s),Put it in the format of [func1(params_name=params_value, params_name2=params_value2...), func2(params)]

NO other text MUST be included.<|eot_id|><|start_header_id|>assistant<|end_header_id|>


```

##### Model Response Format
```
[get_user_info(user_id=7890, special='black')]<|eot_id|>
```

##### Notes

- The tool call format for the model is the same whether your function calls are provided in the system or user message.
- While builtin tool calls end with a <|eom_id|>, notice the <|eot_id|> for zero shot tool calls.

## Code Interpreter

Code Interpreter continues to work in 3.2 text models similar to Llama 3.1 model family.
Here is an example:

##### Input Prompt Format
```
<|begin_of_text|><|start_header_id|>system<|end_header_id|>

Environment: ipython

Cutting Knowledge Date: December 2023
Today Date: 24 September 2024
<|eot_id|><|start_header_id|>user<|end_header_id|>

Write code to check if number is prime. Use it to verify if number 7 is prime<|eot_id|><|start_header_id|>assistant<|end_header_id|>


```

##### Model Response Format
```
<|python_tag|>def is_prime(n):
    if n <= 1:
        return False
    if n == 2:
        return True
    if n % 2 == 0:
        return False
    max_divisor = int(n**0.5) + 1
    for d in range(3, max_divisor, 2):
        if n % d == 0:
            return False
    return True

print(is_prime(7))  # Output: True<|eom_id|>
```

##### Notes

- Note `Environment: ipython` in the system prompt.
- Note that the response starts with `<|python_tag|>` and ends with `<|eom_id|>`

## Zero shot function calling E2E format

Here is an example of the e2e cycle of tool calls with the model in a multi-step way.

##### Input Prompt Format
```
<|begin_of_text|><|start_header_id|>system<|end_header_id|>

You are an expert in composing functions. You are given a question and a set of possible functions.
Based on the question, you will need to make one or more function/tool calls to achieve the purpose.
If none of the function can be used, point it out. If the given question lacks the parameters required by the function,
also point it out. You should only return the function call in tools call sections.

If you decide to invoke any of the function(s), you MUST put it in the format of [func_name1(params_name1=params_value1, params_name2=params_value2...), func_name2(params)]
You SHOULD NOT include any other text in the response.

Here is a list of functions in JSON format that you can invoke.

[
    {
        "name": "get_weather",
        "description": "Get weather info for places",
        "parameters": {
            "type": "dict",
            "required": [
                "city"
            ],
            "properties": {
                "city": {
                    "type": "string",
                    "description": "The name of the city to get the weather for"
                },
                "metric": {
                    "type": "string",
                    "description": "The metric for weather. Options are: celsius, fahrenheit",
                    "default": "celsius"
                }
            }
        }
    }
]<|eot_id|><|start_header_id|>user<|end_header_id|>

What is the weather in SF?<|eot_id|><|start_header_id|>assistant<|end_header_id|>

<|python_tag|>[get_weather(city="San Francisco", metric="celsius")]<|eot_id|><|start_header_id|>ipython<|end_header_id|>

"25 C"<|eot_id|><|start_header_id|>assistant<|end_header_id|>


```

##### Model Response Format
```
The weather in San Francisco is 25 C.<|eot_id|>
```

##### Notes

- The output of the function call is provided back to the model as a tool response ( in json format ).
- Notice `<|start_header_id|>ipython<|end_header_id|>` as the header message preceding the tool response.
- The model finally summarizes the information from the tool response and returns the result to the user.
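To close the loop end to end, the parsed calls have to be dispatched to real Python functions and their results wrapped in an `ipython` message. A hedged sketch of that dispatch step; the registry and the `get_weather` stub below are made up for illustration:

```python
import json

# Hypothetical tool registry: map tool names to plain Python callables.
def get_weather(city: str, metric: str = "celsius") -> str:
    return "25 C"  # stand-in for a real weather lookup

TOOL_REGISTRY = {"get_weather": get_weather}

def dispatch(calls: list[tuple[str, dict]]) -> str:
    """Run each parsed (name, kwargs) call and format the results as an ipython message."""
    results = [TOOL_REGISTRY[name](**kwargs) for name, kwargs in calls]
    body = json.dumps(results if len(results) > 1 else results[0])
    return f"<|start_header_id|>ipython<|end_header_id|>\n\n{body}<|eot_id|>"

print(dispatch([("get_weather", {"city": "San Francisco", "metric": "celsius"})]))
```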
## Prompt format for base models

For base models (Llama3.2-1B and Llama3.2-3B), the prompt format for a simple completion is as follows

##### Input Prompt Format
```
<|begin_of_text|>The color of the sky is blue but sometimes it can also be
```

##### Model Response Format
```
gray or even purple. The color of the sky can change depending on the time of day, the weather, and the amount of pollution in the air. The color of the sky can also be affected by the presence of dust, smoke, and other particles in the air.

## Step 1: Identify the factors that
```

##### Notes

Same as Llama3.1

Thank You!
141
llama_stack/models/llama/llama3_2/vision_prompt_format.md
Normal file

@@ -0,0 +1,141 @@
## User and assistant conversation

Here is a regular multi-turn user assistant conversation and how it is formatted.

##### Input Prompt Format
```
<|begin_of_text|><|start_header_id|>system<|end_header_id|>

You are a helpful assistant<|eot_id|><|start_header_id|>user<|end_header_id|>

Who are you?<|eot_id|><|start_header_id|>assistant<|end_header_id|>


```

##### Model Response Format
```
I'm a helpful assistant, here to provide information, answer questions, and assist with tasks to the best of my abilities. I'm a large language model, which means I can understand and respond to natural language inputs, and I'm constantly learning and improving to provide more accurate and helpful responses.

I can help with a wide range of topics, from general knowledge and trivia to more specific areas like science, history, technology, and more. I can also assist with tasks like language translation, text summarization, and even generating creative content like stories or dialogues.

What can I help you with today?<|eot_id|>
```

##### Notes

This format is unchanged from Llama3.1

## User and assistant conversation with Images

This example shows how to pass an image to the model as part of the messages.

##### Input Prompt Format
```
<|begin_of_text|><|start_header_id|>user<|end_header_id|>

<|image|>Describe this image in two sentences<|eot_id|><|start_header_id|>assistant<|end_header_id|>


```

##### Model Response Format
```
The image depicts a small dog standing on a skateboard, with its front paws firmly planted on the board and its back paws slightly raised. The dog's fur is predominantly brown and white, with a distinctive black stripe running down its back, and it is wearing a black collar around its neck.<|eot_id|>
```

##### Notes

- The `<|image|>` tag is used to indicate presence of the image
- The model isn't an early fusion model so doesn't actually translate an image into several tokens. Instead the cross-attention layers take input "on the side" from a vision encoder
- It's important to position the `<|image|>` tag appropriately in the prompt. The image will only attend to the subsequent text tokens
- The `<|image|>` tag is part of the user message body, implying that it should only come after the header `<|start_header_id|>{role}<|end_header_id|>` in the message body
- We recommend using a single image in one prompt (see the prompt-assembly sketch below)
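As a small illustration of the placement rule above, here is a hedged sketch using plain string assembly (the helper name is hypothetical) showing where the `<|image|>` tag lands inside the user message body:

```python
# Illustrative only: build a vision prompt with the <|image|> tag placed
# at the start of the user message body, right after the role header.
def build_vision_prompt(user_text: str) -> str:
    return (
        "<|begin_of_text|>"
        "<|start_header_id|>user<|end_header_id|>\n\n"
        f"<|image|>{user_text}<|eot_id|>"
        "<|start_header_id|>assistant<|end_header_id|>\n\n"
    )

print(build_vision_prompt("Describe this image in two sentences"))
```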
## Builtin and Zero Shot Tool Calling

Llama3.2 vision models follow the same tool calling format as Llama3.1 models when inputs are text only.
Use `Environment: ipython` to enable tools.
Add `Tools: {{tool_name1}},{{tool_name2}}` for each of the builtin tools.
The same builtin tools as Llama3.1 are available:
- code_interpreter (for executing python code)
- brave_search (to search the web)
- wolfram_alpha (for querying wolfram alpha for mathematical questions)

##### Input Prompt Format
```
<|begin_of_text|><|start_header_id|>system<|end_header_id|>

Environment: ipython
Tools: brave_search, wolfram_alpha
Cutting Knowledge Date: December 2023
Today Date: 23 September 2024

You are a helpful assistant.
<|eot_id|><|start_header_id|>user<|end_header_id|>

Search the web for the latest price of 1oz gold?<|eot_id|><|start_header_id|>assistant<|end_header_id|>


```

##### Model Response Format
```
<|python_tag|>brave_search.call(query="latest price of 1oz gold")<|eom_id|>
```

##### Notes

- Note the `<|python_tag|>` before `brave_search` function call.
- The `<|eom_id|>` tag is used to indicate the end of the message.
- Similar to Llama3.1, code_interpreter is not explicitly mentioned but is enabled via `Environment: ipython`.
- Tool Calling does NOT work with images in the prompt as of now.

## Prompt format for base models

For base models (Llama3.2-11B-Vision and Llama3.2-90B-Vision), the prompt format for a simple completion is as follows

##### Input Prompt Format
```
<|begin_of_text|>The color of the sky is blue but sometimes it can also be
```

##### Model Response Format
```
red, orange, pink, purple, and even black. The color of the sky is determined by the amount of sunlight that is scattered by the atmosphere and the amount of dust and water vapor present in the atmosphere. During sunrise and sunset, the sky can take on a range of colors due to the scattering of light by
```

##### Notes

- Same as Llama3.1

## Prompt format for base models with Image

For base models (Llama3.2-11B-Vision and Llama3.2-90B-Vision), here is an example of how the text completion format looks with an image,

##### Input Prompt Format
```
<|begin_of_text|><|image|>If I had to write a haiku for this one
```

##### Model Response Format
```
, it would be: A skateboarder's delight, a puppy on a board, a furry little thrill-seeker. This puppy is a true skateboarding enthusiast, always eager to hit the streets and show off his skills. He's a master of the board, gliding effortlessly across the pavement with grace and style.
```

##### Notes

- Note the placement of the special tags <|begin_of_text|> and <|image|>

Thank You!
@ -14,7 +14,7 @@
|
||||||
import textwrap
|
import textwrap
|
||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
from llama_models.datatypes import (
|
from llama_stack.models.llama.datatypes import (
|
||||||
BuiltinTool,
|
BuiltinTool,
|
||||||
RawMessage,
|
RawMessage,
|
||||||
StopReason,
|
StopReason,
|
||||||
|
|
|
||||||
|
|
@ -16,7 +16,9 @@ import textwrap
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
from llama_models.datatypes import (
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
from llama_stack.models.llama.datatypes import (
|
||||||
RawContent,
|
RawContent,
|
||||||
RawMediaItem,
|
RawMediaItem,
|
||||||
RawMessage,
|
RawMessage,
|
||||||
|
|
@ -25,7 +27,6 @@ from llama_models.datatypes import (
|
||||||
ToolCall,
|
ToolCall,
|
||||||
ToolPromptFormat,
|
ToolPromptFormat,
|
||||||
)
|
)
|
||||||
from pydantic import BaseModel, Field
|
|
||||||
|
|
||||||
from .llama3.interface import LLama31Interface
|
from .llama3.interface import LLama31Interface
|
||||||
from .llama3.template_data import (
|
from .llama3.template_data import (
|
||||||
|
|
|
||||||
|
|
@ -6,7 +6,6 @@
|
||||||
|
|
||||||
import copy
|
import copy
|
||||||
import json
|
import json
|
||||||
import logging
|
|
||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
import secrets
|
import secrets
|
||||||
|
|
@ -17,8 +16,8 @@ from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple
|
||||||
from urllib.parse import urlparse
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
from pydantic import TypeAdapter
|
|
||||||
|
|
||||||
|
from llama_stack import logcat
|
||||||
from llama_stack.apis.agents import (
|
from llama_stack.apis.agents import (
|
||||||
AgentConfig,
|
AgentConfig,
|
||||||
AgentToolGroup,
|
AgentToolGroup,
|
||||||
|
|
@ -62,7 +61,12 @@ from llama_stack.apis.inference import (
|
||||||
UserMessage,
|
UserMessage,
|
||||||
)
|
)
|
||||||
from llama_stack.apis.safety import Safety
|
from llama_stack.apis.safety import Safety
|
||||||
from llama_stack.apis.tools import RAGDocument, RAGQueryConfig, ToolGroups, ToolInvocationResult, ToolRuntime
|
from llama_stack.apis.tools import (
|
||||||
|
RAGDocument,
|
||||||
|
ToolGroups,
|
||||||
|
ToolInvocationResult,
|
||||||
|
ToolRuntime,
|
||||||
|
)
|
||||||
from llama_stack.apis.vector_io import VectorIO
|
from llama_stack.apis.vector_io import VectorIO
|
||||||
from llama_stack.models.llama.datatypes import (
|
from llama_stack.models.llama.datatypes import (
|
||||||
BuiltinTool,
|
BuiltinTool,
|
||||||
|
|
@ -70,21 +74,18 @@ from llama_stack.models.llama.datatypes import (
|
||||||
ToolParamDefinition,
|
ToolParamDefinition,
|
||||||
)
|
)
|
||||||
from llama_stack.providers.utils.kvstore import KVStore
|
from llama_stack.providers.utils.kvstore import KVStore
|
||||||
from llama_stack.providers.utils.memory.vector_store import concat_interleaved_content
|
|
||||||
from llama_stack.providers.utils.telemetry import tracing
|
from llama_stack.providers.utils.telemetry import tracing
|
||||||
|
|
||||||
from .persistence import AgentPersistence
|
from .persistence import AgentPersistence
|
||||||
from .safety import SafetyException, ShieldRunnerMixin
|
from .safety import SafetyException, ShieldRunnerMixin
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
def make_random_string(length: int = 8):
|
def make_random_string(length: int = 8):
|
||||||
return "".join(secrets.choice(string.ascii_letters + string.digits) for _ in range(length))
|
return "".join(secrets.choice(string.ascii_letters + string.digits) for _ in range(length))
|
||||||
|
|
||||||
|
|
||||||
TOOLS_ATTACHMENT_KEY_REGEX = re.compile(r"__tools_attachment__=(\{.*?\})")
|
TOOLS_ATTACHMENT_KEY_REGEX = re.compile(r"__tools_attachment__=(\{.*?\})")
|
||||||
MEMORY_QUERY_TOOL = "query_from_memory"
|
MEMORY_QUERY_TOOL = "knowledge_search"
|
||||||
WEB_SEARCH_TOOL = "web_search"
|
WEB_SEARCH_TOOL = "web_search"
|
||||||
RAG_TOOL_GROUP = "builtin::rag"
|
RAG_TOOL_GROUP = "builtin::rag"
|
||||||
|
|
||||||
|
|
@ -122,13 +123,25 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
def turn_to_messages(self, turn: Turn) -> List[Message]:
|
def turn_to_messages(self, turn: Turn) -> List[Message]:
|
||||||
messages = []
|
messages = []
|
||||||
|
|
||||||
# We do not want to keep adding RAG context to the input messages
|
# NOTE: if a toolcall response is in a step, we do not add it when processing the input messages
|
||||||
# May be this should be a parameter of the agentic instance
|
tool_call_ids = set()
|
||||||
# that can define its behavior in a custom way
|
for step in turn.steps:
|
||||||
|
if step.step_type == StepType.tool_execution.value:
|
||||||
|
for response in step.tool_responses:
|
||||||
|
tool_call_ids.add(response.call_id)
|
||||||
|
|
||||||
for m in turn.input_messages:
|
for m in turn.input_messages:
|
||||||
msg = m.model_copy()
|
msg = m.model_copy()
|
||||||
|
# We do not want to keep adding RAG context to the input messages
|
||||||
|
# May be this should be a parameter of the agentic instance
|
||||||
|
# that can define its behavior in a custom way
|
||||||
if isinstance(msg, UserMessage):
|
if isinstance(msg, UserMessage):
|
||||||
msg.context = None
|
msg.context = None
|
||||||
|
if isinstance(msg, ToolResponseMessage):
|
||||||
|
if msg.call_id in tool_call_ids:
|
||||||
|
# NOTE: do not add ToolResponseMessage here, we'll add them in tool_execution steps
|
||||||
|
continue
|
||||||
|
|
||||||
messages.append(msg)
|
messages.append(msg)
|
||||||
|
|
||||||
for step in turn.steps:
|
for step in turn.steps:
|
||||||
|
|
@ -183,7 +196,7 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
|
|
||||||
turn_id = str(uuid.uuid4())
|
turn_id = str(uuid.uuid4())
|
||||||
span.set_attribute("turn_id", turn_id)
|
span.set_attribute("turn_id", turn_id)
|
||||||
start_time = datetime.now()
|
start_time = datetime.now().astimezone().isoformat()
|
||||||
yield AgentTurnResponseStreamChunk(
|
yield AgentTurnResponseStreamChunk(
|
||||||
event=AgentTurnResponseEvent(
|
event=AgentTurnResponseEvent(
|
||||||
payload=AgentTurnResponseTurnStartPayload(
|
payload=AgentTurnResponseTurnStartPayload(
|
||||||
|
|
@ -204,8 +217,9 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
toolgroups_for_turn=request.toolgroups,
|
toolgroups_for_turn=request.toolgroups,
|
||||||
):
|
):
|
||||||
if isinstance(chunk, CompletionMessage):
|
if isinstance(chunk, CompletionMessage):
|
||||||
log.info(
|
logcat.info(
|
||||||
f"{chunk.role.capitalize()}: {chunk.content}",
|
"agents",
|
||||||
|
f"returning result from the agent turn: {chunk}",
|
||||||
)
|
)
|
||||||
output_message = chunk
|
output_message = chunk
|
||||||
continue
|
continue
|
||||||
|
|
@ -225,12 +239,11 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
input_messages=request.messages,
|
input_messages=request.messages,
|
||||||
output_message=output_message,
|
output_message=output_message,
|
||||||
started_at=start_time,
|
started_at=start_time,
|
||||||
completed_at=datetime.now(),
|
completed_at=datetime.now().astimezone().isoformat(),
|
||||||
steps=steps,
|
steps=steps,
|
||||||
)
|
)
|
||||||
await self.storage.add_turn_to_session(request.session_id, turn)
|
await self.storage.add_turn_to_session(request.session_id, turn)
|
||||||
|
if output_message.tool_calls:
|
||||||
if output_message.tool_calls and request.allow_turn_resume:
|
|
||||||
chunk = AgentTurnResponseStreamChunk(
|
chunk = AgentTurnResponseStreamChunk(
|
||||||
event=AgentTurnResponseEvent(
|
event=AgentTurnResponseEvent(
|
||||||
payload=AgentTurnResponseTurnAwaitingInputPayload(
|
payload=AgentTurnResponseTurnAwaitingInputPayload(
|
||||||
|
|
@ -262,17 +275,24 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
raise ValueError(f"Session {request.session_id} not found")
|
raise ValueError(f"Session {request.session_id} not found")
|
||||||
|
|
||||||
turns = await self.storage.get_session_turns(request.session_id)
|
turns = await self.storage.get_session_turns(request.session_id)
|
||||||
|
if len(turns) == 0:
|
||||||
|
raise ValueError("No turns found for session")
|
||||||
|
|
||||||
messages = await self.get_messages_from_turns(turns)
|
messages = await self.get_messages_from_turns(turns)
|
||||||
messages.extend(request.tool_responses)
|
messages.extend(request.tool_responses)
|
||||||
|
|
||||||
|
last_turn = turns[-1]
|
||||||
|
last_turn_messages = self.turn_to_messages(last_turn)
|
||||||
last_turn_messages = [
|
last_turn_messages = [
|
||||||
x for x in messages if isinstance(x, UserMessage) or isinstance(x, ToolResponseMessage)
|
x for x in last_turn_messages if isinstance(x, UserMessage) or isinstance(x, ToolResponseMessage)
|
||||||
]
|
]
|
||||||
|
|
||||||
|
# TODO: figure out whether we should add the tool responses to the last turn messages
|
||||||
|
last_turn_messages.extend(request.tool_responses)
|
||||||
|
|
||||||
# get the steps from the turn id
|
# get the steps from the turn id
|
||||||
steps = []
|
steps = []
|
||||||
if len(turns) > 0:
|
steps = turns[-1].steps
|
||||||
steps = turns[-1].steps
|
|
||||||
|
|
||||||
# mark tool execution step as complete
|
# mark tool execution step as complete
|
||||||
# if there's no tool execution in progress step (due to storage, or tool call parsing on client),
|
# if there's no tool execution in progress step (due to storage, or tool call parsing on client),
|
||||||
|
|
@ -280,7 +300,7 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
in_progress_tool_call_step = await self.storage.get_in_progress_tool_call_step(
|
in_progress_tool_call_step = await self.storage.get_in_progress_tool_call_step(
|
||||||
request.session_id, request.turn_id
|
request.session_id, request.turn_id
|
||||||
)
|
)
|
||||||
now = datetime.now()
|
now = datetime.now().astimezone().isoformat()
|
||||||
tool_execution_step = ToolExecutionStep(
|
tool_execution_step = ToolExecutionStep(
|
||||||
step_id=(in_progress_tool_call_step.step_id if in_progress_tool_call_step else str(uuid.uuid4())),
|
step_id=(in_progress_tool_call_step.step_id if in_progress_tool_call_step else str(uuid.uuid4())),
|
||||||
turn_id=request.turn_id,
|
turn_id=request.turn_id,
|
||||||
|
|
@ -328,7 +348,7 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
|
|
||||||
assert output_message is not None
|
assert output_message is not None
|
||||||
|
|
||||||
last_turn_start_time = datetime.now()
|
last_turn_start_time = datetime.now().astimezone().isoformat()
|
||||||
if len(turns) > 0:
|
if len(turns) > 0:
|
||||||
last_turn_start_time = turns[-1].started_at
|
last_turn_start_time = turns[-1].started_at
|
||||||
|
|
||||||
|
|
@ -338,7 +358,7 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
input_messages=last_turn_messages,
|
input_messages=last_turn_messages,
|
||||||
output_message=output_message,
|
output_message=output_message,
|
||||||
started_at=last_turn_start_time,
|
started_at=last_turn_start_time,
|
||||||
completed_at=datetime.now(),
|
completed_at=datetime.now().astimezone().isoformat(),
|
||||||
steps=steps,
|
steps=steps,
|
||||||
)
|
)
|
||||||
await self.storage.add_turn_to_session(request.session_id, turn)
|
await self.storage.add_turn_to_session(request.session_id, turn)
|
||||||
|
|
@ -432,7 +452,7 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
return
|
return
|
||||||
|
|
||||||
step_id = str(uuid.uuid4())
|
step_id = str(uuid.uuid4())
|
||||||
shield_call_start_time = datetime.now()
|
shield_call_start_time = datetime.now().astimezone().isoformat()
|
||||||
try:
|
try:
|
||||||
yield AgentTurnResponseStreamChunk(
|
yield AgentTurnResponseStreamChunk(
|
||||||
event=AgentTurnResponseEvent(
|
event=AgentTurnResponseEvent(
|
||||||
|
|
@ -456,7 +476,7 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
turn_id=turn_id,
|
turn_id=turn_id,
|
||||||
violation=e.violation,
|
violation=e.violation,
|
||||||
started_at=shield_call_start_time,
|
started_at=shield_call_start_time,
|
||||||
completed_at=datetime.now(),
|
completed_at=datetime.now().astimezone().isoformat(),
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
@ -479,7 +499,7 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
turn_id=turn_id,
|
turn_id=turn_id,
|
||||||
violation=None,
|
violation=None,
|
||||||
started_at=shield_call_start_time,
|
started_at=shield_call_start_time,
|
||||||
completed_at=datetime.now(),
|
completed_at=datetime.now().astimezone().isoformat(),
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
@ -499,121 +519,37 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
# TODO: simplify all of this code, it can be simpler
|
# TODO: simplify all of this code, it can be simpler
|
||||||
toolgroup_args = {}
|
toolgroup_args = {}
|
||||||
toolgroups = set()
|
toolgroups = set()
|
||||||
for toolgroup in self.agent_config.toolgroups:
|
for toolgroup in self.agent_config.toolgroups + (toolgroups_for_turn or []):
|
||||||
if isinstance(toolgroup, AgentToolGroupWithArgs):
|
if isinstance(toolgroup, AgentToolGroupWithArgs):
|
||||||
toolgroups.add(toolgroup.name)
|
tool_group_name, tool_name = self._parse_toolgroup_name(toolgroup.name)
|
||||||
toolgroup_args[toolgroup.name] = toolgroup.args
|
toolgroups.add(tool_group_name)
|
||||||
|
toolgroup_args[tool_group_name] = toolgroup.args
|
||||||
else:
|
else:
|
||||||
toolgroups.add(toolgroup)
|
toolgroups.add(toolgroup)
|
||||||
if toolgroups_for_turn:
|
|
||||||
for toolgroup in toolgroups_for_turn:
|
|
||||||
if isinstance(toolgroup, AgentToolGroupWithArgs):
|
|
||||||
toolgroups.add(toolgroup.name)
|
|
||||||
toolgroup_args[toolgroup.name] = toolgroup.args
|
|
||||||
else:
|
|
||||||
toolgroups.add(toolgroup)
|
|
||||||
|
|
||||||
tool_defs, tool_to_group = await self._get_tool_defs(toolgroups_for_turn)
|
tool_defs, tool_to_group = await self._get_tool_defs(toolgroups_for_turn)
|
||||||
if documents:
|
if documents:
|
||||||
await self.handle_documents(session_id, documents, input_messages, tool_defs)
|
await self.handle_documents(session_id, documents, input_messages, tool_defs)
|
||||||
|
|
||||||
if RAG_TOOL_GROUP in toolgroups and len(input_messages) > 0:
|
session_info = await self.storage.get_session_info(session_id)
|
||||||
with tracing.span(MEMORY_QUERY_TOOL) as span:
|
# if the session has a memory bank id, let the memory tool use it
|
||||||
step_id = str(uuid.uuid4())
|
if session_info and session_info.vector_db_id:
|
||||||
yield AgentTurnResponseStreamChunk(
|
if RAG_TOOL_GROUP not in toolgroup_args:
|
||||||
event=AgentTurnResponseEvent(
|
toolgroup_args[RAG_TOOL_GROUP] = {"vector_db_ids": [session_info.vector_db_id]}
|
||||||
payload=AgentTurnResponseStepStartPayload(
|
else:
|
||||||
step_type=StepType.tool_execution.value,
|
toolgroup_args[RAG_TOOL_GROUP]["vector_db_ids"].append(session_info.vector_db_id)
|
||||||
step_id=step_id,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
args = toolgroup_args.get(RAG_TOOL_GROUP, {})
|
|
||||||
vector_db_ids = args.get("vector_db_ids", [])
|
|
||||||
query_config = args.get("query_config")
|
|
||||||
if query_config:
|
|
||||||
query_config = TypeAdapter(RAGQueryConfig).validate_python(query_config)
|
|
||||||
else:
|
|
||||||
# handle someone passing an empty dict
|
|
||||||
query_config = RAGQueryConfig()
|
|
||||||
|
|
||||||
session_info = await self.storage.get_session_info(session_id)
|
|
||||||
|
|
||||||
# if the session has a memory bank id, let the memory tool use it
|
|
||||||
if session_info.vector_db_id:
|
|
||||||
vector_db_ids.append(session_info.vector_db_id)
|
|
||||||
|
|
||||||
yield AgentTurnResponseStreamChunk(
|
|
||||||
event=AgentTurnResponseEvent(
|
|
||||||
payload=AgentTurnResponseStepProgressPayload(
|
|
||||||
step_type=StepType.tool_execution.value,
|
|
||||||
step_id=step_id,
|
|
||||||
delta=ToolCallDelta(
|
|
||||||
parse_status=ToolCallParseStatus.succeeded,
|
|
||||||
tool_call=ToolCall(
|
|
||||||
call_id="",
|
|
||||||
tool_name=MEMORY_QUERY_TOOL,
|
|
||||||
arguments={},
|
|
||||||
),
|
|
||||||
),
|
|
||||||
)
|
|
||||||
)
|
|
||||||
)
|
|
||||||
result = await self.tool_runtime_api.rag_tool.query(
|
|
||||||
content=concat_interleaved_content([msg.content for msg in input_messages]),
|
|
||||||
vector_db_ids=vector_db_ids,
|
|
||||||
query_config=query_config,
|
|
||||||
)
|
|
||||||
retrieved_context = result.content
|
|
||||||
|
|
||||||
yield AgentTurnResponseStreamChunk(
|
|
||||||
event=AgentTurnResponseEvent(
|
|
||||||
payload=AgentTurnResponseStepCompletePayload(
|
|
||||||
step_type=StepType.tool_execution.value,
|
|
||||||
step_id=step_id,
|
|
||||||
step_details=ToolExecutionStep(
|
|
||||||
step_id=step_id,
|
|
||||||
turn_id=turn_id,
|
|
||||||
tool_calls=[
|
|
||||||
ToolCall(
|
|
||||||
call_id="",
|
|
||||||
tool_name=MEMORY_QUERY_TOOL,
|
|
||||||
arguments={},
|
|
||||||
)
|
|
||||||
],
|
|
||||||
tool_responses=[
|
|
||||||
ToolResponse(
|
|
||||||
call_id="",
|
|
||||||
tool_name=MEMORY_QUERY_TOOL,
|
|
||||||
content=retrieved_context or [],
|
|
||||||
metadata=result.metadata,
|
|
||||||
)
|
|
||||||
],
|
|
||||||
),
|
|
||||||
)
|
|
||||||
)
|
|
||||||
)
|
|
||||||
span.set_attribute("input", [m.model_dump_json() for m in input_messages])
|
|
||||||
span.set_attribute("output", retrieved_context)
|
|
||||||
span.set_attribute("tool_name", MEMORY_QUERY_TOOL)
|
|
||||||
|
|
||||||
# append retrieved_context to the last user message
|
|
||||||
for message in input_messages[::-1]:
|
|
||||||
if isinstance(message, UserMessage):
|
|
||||||
message.context = retrieved_context
|
|
||||||
break
|
|
||||||
|
|
||||||
output_attachments = []
|
output_attachments = []
|
||||||
|
|
||||||
n_iter = 0
|
n_iter = await self.storage.get_num_infer_iters_in_turn(session_id, turn_id) or 0
|
||||||
|
|
||||||
# Build a map of custom tools to their definitions for faster lookup
|
# Build a map of custom tools to their definitions for faster lookup
|
||||||
client_tools = {}
|
client_tools = {}
|
||||||
for tool in self.agent_config.client_tools:
|
for tool in self.agent_config.client_tools:
|
||||||
client_tools[tool.name] = tool
|
client_tools[tool.name] = tool
|
||||||
while True:
|
while True:
|
||||||
step_id = str(uuid.uuid4())
|
step_id = str(uuid.uuid4())
|
||||||
inference_start_time = datetime.now()
|
inference_start_time = datetime.now().astimezone().isoformat()
|
||||||
yield AgentTurnResponseStreamChunk(
|
yield AgentTurnResponseStreamChunk(
|
||||||
event=AgentTurnResponseEvent(
|
event=AgentTurnResponseEvent(
|
||||||
payload=AgentTurnResponseStepStartPayload(
|
payload=AgentTurnResponseStepStartPayload(
|
||||||
|
|
@ -631,9 +567,7 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
async for chunk in await self.inference_api.chat_completion(
|
async for chunk in await self.inference_api.chat_completion(
|
||||||
self.agent_config.model,
|
self.agent_config.model,
|
||||||
input_messages,
|
input_messages,
|
||||||
tools=[
|
tools=tool_defs,
|
||||||
tool for tool in tool_defs.values() if tool_to_group.get(tool.tool_name, None) != RAG_TOOL_GROUP
|
|
||||||
],
|
|
||||||
tool_prompt_format=self.agent_config.tool_config.tool_prompt_format,
|
tool_prompt_format=self.agent_config.tool_config.tool_prompt_format,
|
||||||
response_format=self.agent_config.response_format,
|
response_format=self.agent_config.response_format,
|
||||||
stream=True,
|
stream=True,
|
||||||
|
|
@ -683,8 +617,20 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
if event.stop_reason is not None:
|
if event.stop_reason is not None:
|
||||||
stop_reason = event.stop_reason
|
stop_reason = event.stop_reason
|
||||||
span.set_attribute("stop_reason", stop_reason)
|
span.set_attribute("stop_reason", stop_reason)
|
||||||
span.set_attribute("input", [m.model_dump_json() for m in input_messages])
|
span.set_attribute(
|
||||||
span.set_attribute("output", f"content: {content} tool_calls: {tool_calls}")
|
"input",
|
||||||
|
json.dumps([json.loads(m.model_dump_json()) for m in input_messages]),
|
||||||
|
)
|
||||||
|
output_attr = json.dumps(
|
||||||
|
{
|
||||||
|
"content": content,
|
||||||
|
"tool_calls": [json.loads(t.model_dump_json()) for t in tool_calls],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
span.set_attribute("output", output_attr)
|
||||||
|
|
||||||
|
n_iter += 1
|
||||||
|
await self.storage.set_num_infer_iters_in_turn(session_id, turn_id, n_iter)
|
||||||
|
|
||||||
stop_reason = stop_reason or StopReason.out_of_tokens
|
stop_reason = stop_reason or StopReason.out_of_tokens
|
||||||
|
|
||||||
|
|
@ -713,19 +659,22 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
turn_id=turn_id,
|
turn_id=turn_id,
|
||||||
model_response=copy.deepcopy(message),
|
model_response=copy.deepcopy(message),
|
||||||
started_at=inference_start_time,
|
started_at=inference_start_time,
|
||||||
completed_at=datetime.now(),
|
completed_at=datetime.now().astimezone().isoformat(),
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
if n_iter >= self.agent_config.max_infer_iters:
|
if n_iter >= self.agent_config.max_infer_iters:
|
||||||
log.info("Done with MAX iterations, exiting.")
|
logcat.info("agents", f"done with MAX iterations ({n_iter}), exiting.")
|
||||||
|
# NOTE: mark end_of_turn to indicate to client that we are done with the turn
|
||||||
|
# Do not continue the tool call loop after this point
|
||||||
|
message.stop_reason = StopReason.end_of_turn
|
||||||
yield message
|
yield message
|
||||||
break
|
break
|
||||||
|
|
||||||
if stop_reason == StopReason.out_of_tokens:
|
if stop_reason == StopReason.out_of_tokens:
|
||||||
log.info("Out of token budget, exiting.")
|
logcat.info("agents", "out of token budget, exiting.")
|
||||||
yield message
|
yield message
|
||||||
break
|
break
|
||||||
|
|
||||||
|
|
@@ -739,10 +688,16 @@ class ChatAgent(ShieldRunnerMixin):
                    message.content = [message.content] + output_attachments
                yield message
            else:
-                log.info(f"Partial message: {str(message)}")
+                logcat.debug(
+                    "agents",
+                    f"completion message with EOM (iter: {n_iter}): {str(message)}",
+                )
                input_messages = input_messages + [message]
        else:
-            log.info(f"{str(message)}")
+            logcat.debug(
+                "agents",
+                f"completion message (iter: {n_iter}) from the model: {str(message)}",
+            )
            # 1. Start the tool execution step and progress
            step_id = str(uuid.uuid4())
            yield AgentTurnResponseStreamChunk(
@@ -770,6 +725,9 @@ class ChatAgent(ShieldRunnerMixin):

            # If tool is a client tool, yield CompletionMessage and return
            if tool_call.tool_name in client_tools:
+                # NOTE: mark end_of_message to indicate to client that it may
+                # call the tool and continue the conversation with the tool's response.
+                message.stop_reason = StopReason.end_of_message
                await self.storage.set_in_progress_tool_call_step(
                    session_id,
                    turn_id,
@@ -778,7 +736,7 @@ class ChatAgent(ShieldRunnerMixin):
                        turn_id=turn_id,
                        tool_calls=[tool_call],
                        tool_responses=[],
-                        started_at=datetime.now(),
+                        started_at=datetime.now().astimezone().isoformat(),
                    ),
                )
                yield message
@@ -795,7 +753,7 @@ class ChatAgent(ShieldRunnerMixin):
                    "input": message.model_dump_json(),
                },
            ) as span:
-                tool_execution_start_time = datetime.now()
+                tool_execution_start_time = datetime.now().astimezone().isoformat()
                tool_call = message.tool_calls[0]
                tool_result = await execute_tool_call_maybe(
                    self.tool_runtime_api,
@@ -804,6 +762,10 @@ class ChatAgent(ShieldRunnerMixin):
                    toolgroup_args,
                    tool_to_group,
                )
+                if tool_result.content is None:
+                    raise ValueError(
+                        f"Tool call result (id: {tool_call.call_id}, name: {tool_call.tool_name}) does not have any content"
+                    )
                result_messages = [
                    ToolResponseMessage(
                        call_id=tool_call.call_id,
@@ -833,7 +795,7 @@ class ChatAgent(ShieldRunnerMixin):
                            )
                        ],
                        started_at=tool_execution_start_time,
-                        completed_at=datetime.now(),
+                        completed_at=datetime.now().astimezone().isoformat(),
                    ),
                )
            )
@@ -841,8 +803,9 @@ class ChatAgent(ShieldRunnerMixin):

            # TODO: add tool-input touchpoint and a "start" event for this step also
            # but that needs a lot more refactoring of Tool code potentially
-            if out_attachment := _interpret_content_as_attachment(result_message.content):
+            if (type(result_message.content) is str) and (
+                out_attachment := _interpret_content_as_attachment(result_message.content)
+            ):
                # NOTE: when we push this message back to the model, the model may ignore the
                # attached file path etc. since the model is trained to only provide a user message
                # with the summary. We keep all generated attachments and then attach them to final message
@@ -850,32 +813,24 @@ class ChatAgent(ShieldRunnerMixin):

            input_messages = input_messages + [message, result_message]

-            n_iter += 1

    async def _get_tool_defs(
        self, toolgroups_for_turn: Optional[List[AgentToolGroup]] = None
-    ) -> Tuple[Dict[str, ToolDefinition], Dict[str, str]]:
+    ) -> Tuple[List[ToolDefinition], Dict[str, str]]:
        # Determine which tools to include
-        agent_config_toolgroups = set(
-            (toolgroup.name if isinstance(toolgroup, AgentToolGroupWithArgs) else toolgroup)
-            for toolgroup in self.agent_config.toolgroups
-        )
-        toolgroups_for_turn_set = (
-            agent_config_toolgroups
-            if toolgroups_for_turn is None
-            else {
-                (toolgroup.name if isinstance(toolgroup, AgentToolGroupWithArgs) else toolgroup)
-                for toolgroup in toolgroups_for_turn
-            }
-        )
+        tool_groups_to_include = toolgroups_for_turn or self.agent_config.toolgroups or []
+        agent_config_toolgroups = []
+        for toolgroup in tool_groups_to_include:
+            name = toolgroup.name if isinstance(toolgroup, AgentToolGroupWithArgs) else toolgroup
+            if name not in agent_config_toolgroups:
+                agent_config_toolgroups.append(name)

-        tool_def_map = {}
+        tool_name_to_def = {}
        tool_to_group = {}

        for tool_def in self.agent_config.client_tools:
-            if tool_def_map.get(tool_def.name, None):
+            if tool_name_to_def.get(tool_def.name, None):
                raise ValueError(f"Tool {tool_def.name} already exists")
-            tool_def_map[tool_def.name] = ToolDefinition(
+            tool_name_to_def[tool_def.name] = ToolDefinition(
                tool_name=tool_def.name,
                description=tool_def.description,
                parameters={
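The rewritten toolgroup selection above replaces two set expressions with an order-preserving de-duplication, so turn-level overrides win and first-seen order is kept. A standalone sketch of that pattern (toolgroup names are illustrative):

def dedup_keep_order(names):
    seen = []
    for name in names:
        if name not in seen:
            seen.append(name)
    return seen


print(dedup_keep_order(["builtin::websearch", "builtin::rag", "builtin::websearch"]))
# ['builtin::websearch', 'builtin::rag']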
@@ -889,10 +844,19 @@ class ChatAgent(ShieldRunnerMixin):
                },
            )
            tool_to_group[tool_def.name] = "__client_tools__"
-        for toolgroup_name in agent_config_toolgroups:
-            if toolgroup_name not in toolgroups_for_turn_set:
-                continue
+        for toolgroup_name_with_maybe_tool_name in agent_config_toolgroups:
+            toolgroup_name, tool_name = self._parse_toolgroup_name(toolgroup_name_with_maybe_tool_name)
            tools = await self.tool_groups_api.list_tools(toolgroup_id=toolgroup_name)
+            if not tools.data:
+                available_tool_groups = ", ".join(
+                    [t.identifier for t in (await self.tool_groups_api.list_tool_groups()).data]
+                )
+                raise ValueError(f"Toolgroup {toolgroup_name} not found, available toolgroups: {available_tool_groups}")
+            if tool_name is not None and not any(tool.identifier == tool_name for tool in tools.data):
+                raise ValueError(
+                    f"Tool {tool_name} not found in toolgroup {toolgroup_name}. Available tools: {', '.join([tool.identifier for tool in tools.data])}"
+                )

            for tool_def in tools.data:
                if toolgroup_name.startswith("builtin") and toolgroup_name != RAG_TOOL_GROUP:
                    tool_name = tool_def.identifier
@@ -902,31 +866,61 @@ class ChatAgent(ShieldRunnerMixin):
                    else:
                        built_in_type = BuiltinTool(tool_name)

-                    if tool_def_map.get(built_in_type, None):
+                    if tool_name_to_def.get(built_in_type, None):
                        raise ValueError(f"Tool {built_in_type} already exists")

-                    tool_def_map[built_in_type] = ToolDefinition(tool_name=built_in_type)
+                    tool_name_to_def[built_in_type] = ToolDefinition(
+                        tool_name=built_in_type,
+                        description=tool_def.description,
+                        parameters={
+                            param.name: ToolParamDefinition(
+                                param_type=param.parameter_type,
+                                description=param.description,
+                                required=param.required,
+                                default=param.default,
+                            )
+                            for param in tool_def.parameters
+                        },
+                    )
                    tool_to_group[built_in_type] = tool_def.toolgroup_id
                    continue

-                if tool_def_map.get(tool_def.identifier, None):
+                if tool_name_to_def.get(tool_def.identifier, None):
                    raise ValueError(f"Tool {tool_def.identifier} already exists")
-                tool_def_map[tool_def.identifier] = ToolDefinition(
-                    tool_name=tool_def.identifier,
-                    description=tool_def.description,
-                    parameters={
-                        param.name: ToolParamDefinition(
-                            param_type=param.parameter_type,
-                            description=param.description,
-                            required=param.required,
-                            default=param.default,
-                        )
-                        for param in tool_def.parameters
-                    },
-                )
-                tool_to_group[tool_def.identifier] = tool_def.toolgroup_id
+                if tool_name in (None, tool_def.identifier):
+                    tool_name_to_def[tool_def.identifier] = ToolDefinition(
+                        tool_name=tool_def.identifier,
+                        description=tool_def.description,
+                        parameters={
+                            param.name: ToolParamDefinition(
+                                param_type=param.parameter_type,
+                                description=param.description,
+                                required=param.required,
+                                default=param.default,
+                            )
+                            for param in tool_def.parameters
+                        },
+                    )
+                    tool_to_group[tool_def.identifier] = tool_def.toolgroup_id

-        return tool_def_map, tool_to_group
+        return list(tool_name_to_def.values()), tool_to_group
+
+    def _parse_toolgroup_name(self, toolgroup_name_with_maybe_tool_name: str) -> tuple[str, Optional[str]]:
+        """Parse a toolgroup name into its components.
+
+        Args:
+            toolgroup_name: The toolgroup name to parse (e.g. "builtin::rag/knowledge_search")
+
+        Returns:
+            A tuple of (tool_type, tool_group, tool_name)
+        """
+        split_names = toolgroup_name_with_maybe_tool_name.split("/")
+        if len(split_names) == 2:
+            # e.g. "builtin::rag"
+            tool_group, tool_name = split_names
+        else:
+            tool_group, tool_name = split_names[0], None
+        return tool_group, tool_name

    async def handle_documents(
        self,
|
||||||
input_messages: List[Message],
|
input_messages: List[Message],
|
||||||
tool_defs: Dict[str, ToolDefinition],
|
tool_defs: Dict[str, ToolDefinition],
|
||||||
) -> None:
|
) -> None:
|
||||||
memory_tool = tool_defs.get(MEMORY_QUERY_TOOL, None)
|
memory_tool = any(tool_def.tool_name == MEMORY_QUERY_TOOL for tool_def in tool_defs)
|
||||||
code_interpreter_tool = tool_defs.get(BuiltinTool.code_interpreter, None)
|
code_interpreter_tool = any(tool_def.tool_name == BuiltinTool.code_interpreter for tool_def in tool_defs)
|
||||||
content_items = []
|
content_items = []
|
||||||
url_items = []
|
url_items = []
|
||||||
pattern = re.compile("^(https?://|file://|data:)")
|
pattern = re.compile("^(https?://|file://|data:)")
|
||||||
|
|
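handle_documents above decides how to treat each document by matching it against the pattern ^(https?://|file://|data:). A quick standalone check of that classification (sample inputs are made up):

import re

pattern = re.compile("^(https?://|file://|data:)")
for value in ["https://example.com/doc.txt", "file:///tmp/notes.md", "plain text content"]:
    print(value, "->", bool(pattern.match(value)))
# URL-like values match and are handled as URLs; everything else is treated as raw content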
@@ -1046,7 +1040,7 @@ async def attachment_message(tempdir: str, urls: List[URL]) -> ToolResponseMessage:
            path = urlparse(uri).path
            basename = os.path.basename(path)
            filepath = f"{tempdir}/{make_random_string() + basename}"
-            log.info(f"Downloading {url} -> {filepath}")
+            logcat.info("agents", f"Downloading {url} -> {filepath}")

            async with httpx.AsyncClient() as client:
                r = await client.get(uri)
@@ -1056,7 +1050,11 @@ async def attachment_message(tempdir: str, urls: List[URL]) -> ToolResponseMessage:
        else:
            raise ValueError(f"Unsupported URL {url}")

-        content.append(TextContentItem(text=f'# There is a file accessible to you at "{filepath}"\n'))
+        content.append(
+            TextContentItem(
+                text=f'# User provided a file accessible to you at "{filepath}"\nYou can use code_interpreter to load and inspect it.'
+            )
+        )

    return ToolResponseMessage(
        call_id="",
@@ -1076,22 +1074,23 @@ async def execute_tool_call_maybe(
    group_name = tool_to_group.get(name, None)
    if group_name is None:
        raise ValueError(f"Tool {name} not found in any tool group")
-    # get the arguments generated by the model and augment with toolgroup arg overrides for the agent
-    tool_call_args = tool_call.arguments
-    tool_call_args.update(toolgroup_args.get(group_name, {}))
    if isinstance(name, BuiltinTool):
        if name == BuiltinTool.brave_search:
            name = WEB_SEARCH_TOOL
        else:
            name = name.value

+    logcat.info("agents", f"executing tool call: {name} with args: {tool_call.arguments}")
    result = await tool_runtime_api.invoke_tool(
        tool_name=name,
-        kwargs=dict(
-            session_id=session_id,
-            **tool_call_args,
-        ),
+        kwargs={
+            "session_id": session_id,
+            # get the arguments generated by the model and augment with toolgroup arg overrides for the agent
+            **tool_call.arguments,
+            **toolgroup_args.get(group_name, {}),
+        },
    )
+    logcat.debug("agents", f"tool call {name} completed with result: {result}")
    return result

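The invoke_tool call above now builds its kwargs with a literal dict merge; because the toolgroup overrides are unpacked last, they take precedence over arguments generated by the model. A standalone sketch of that precedence (argument names and values are illustrative):

model_args = {"query": "llama 3 release date", "top_k": 3}  # as produced by the model
toolgroup_overrides = {"top_k": 10}                          # per-agent toolgroup args

kwargs = {
    "session_id": "session-123",  # hypothetical id, for illustration only
    **model_args,
    **toolgroup_overrides,
}
print(kwargs["top_k"])  # 10 -- the later unpacking wins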
@@ -140,7 +140,6 @@ class MetaReferenceAgentsImpl(Agents):
        documents: Optional[List[Document]] = None,
        stream: Optional[bool] = False,
        tool_config: Optional[ToolConfig] = None,
-        allow_turn_resume: Optional[bool] = False,
    ) -> AsyncGenerator:
        request = AgentTurnCreateRequest(
            agent_id=agent_id,
@@ -150,7 +149,6 @@ class MetaReferenceAgentsImpl(Agents):
            toolgroups=toolgroups,
            documents=documents,
            tool_config=tool_config,
-            allow_turn_resume=allow_turn_resume,
        )
        if stream:
            return self._create_agent_turn_streaming(request)
@@ -194,17 +192,13 @@ class MetaReferenceAgentsImpl(Agents):
            yield event

    async def get_agents_turn(self, agent_id: str, session_id: str, turn_id: str) -> Turn:
-        turn = await self.persistence_store.get(f"session:{agent_id}:{session_id}:{turn_id}")
-        turn = json.loads(turn)
-        turn = Turn(**turn)
+        agent = await self.get_agent(agent_id)
+        turn = await agent.storage.get_session_turn(session_id, turn_id)
        return turn

    async def get_agents_step(self, agent_id: str, session_id: str, turn_id: str, step_id: str) -> AgentStepResponse:
-        turn = await self.persistence_store.get(f"session:{agent_id}:{session_id}:{turn_id}")
-        turn = json.loads(turn)
-        turn = Turn(**turn)
-        steps = turn.steps
-        for step in steps:
+        turn = await self.get_agents_turn(agent_id, session_id, turn_id)
+        for step in turn.steps:
            if step.step_id == step_id:
                return AgentStepResponse(step=step)
        raise ValueError(f"Provided step_id {step_id} could not be found")
@@ -215,20 +209,18 @@ class MetaReferenceAgentsImpl(Agents):
        session_id: str,
        turn_ids: Optional[List[str]] = None,
    ) -> Session:
-        session = await self.persistence_store.get(f"session:{agent_id}:{session_id}")
-        session = Session(**json.loads(session), turns=[])
-        turns = []
+        agent = await self.get_agent(agent_id)
+        session_info = await agent.storage.get_session_info(session_id)
+        if session_info is None:
+            raise ValueError(f"Session {session_id} not found")
+        turns = await agent.storage.get_session_turns(session_id)
        if turn_ids:
-            for turn_id in turn_ids:
-                turn = await self.persistence_store.get(f"session:{agent_id}:{session_id}:{turn_id}")
-                turn = json.loads(turn)
-                turn = Turn(**turn)
-                turns.append(turn)
+            turns = [turn for turn in turns if turn.turn_id in turn_ids]
        return Session(
-            session_name=session.session_name,
+            session_name=session_info.session_name,
            session_id=session_id,
-            turns=turns if turns else [],
-            started_at=session.started_at,
+            turns=turns,
+            started_at=session_info.started_at,
        )

    async def delete_agents_session(self, agent_id: str, session_id: str) -> None:

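get_agents_session above now reads turns through the agent's persistence layer and, when turn_ids is given, narrows them with a list comprehension instead of re-reading each turn from the store. A toy sketch of that filtering step (plain dicts stand in for Turn objects):

turns = [{"turn_id": "t1"}, {"turn_id": "t2"}, {"turn_id": "t3"}]  # stand-ins for Turn objects
turn_ids = ["t1", "t3"]

filtered = [turn for turn in turns if turn["turn_id"] in turn_ids]
print([turn["turn_id"] for turn in filtered])  # ['t1', 't3']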
@@ -21,6 +21,7 @@ log = logging.getLogger(__name__)
class AgentSessionInfo(BaseModel):
    session_id: str
    session_name: str
+    # TODO: is this used anywhere?
    vector_db_id: Optional[str] = None
    started_at: datetime
@@ -85,6 +86,14 @@ class AgentPersistence:
        turns.sort(key=lambda x: (x.completed_at or datetime.min))
        return turns

+    async def get_session_turn(self, session_id: str, turn_id: str) -> Optional[Turn]:
+        value = await self.kvstore.get(
+            key=f"session:{self.agent_id}:{session_id}:{turn_id}",
+        )
+        if not value:
+            return None
+        return Turn(**json.loads(value))
+
    async def set_in_progress_tool_call_step(self, session_id: str, turn_id: str, step: ToolExecutionStep):
        await self.kvstore.set(
            key=f"in_progress_tool_call_step:{self.agent_id}:{session_id}:{turn_id}",
@@ -96,3 +105,15 @@ class AgentPersistence:
            key=f"in_progress_tool_call_step:{self.agent_id}:{session_id}:{turn_id}",
        )
        return ToolExecutionStep(**json.loads(value)) if value else None
+
+    async def set_num_infer_iters_in_turn(self, session_id: str, turn_id: str, num_infer_iters: int):
+        await self.kvstore.set(
+            key=f"num_infer_iters_in_turn:{self.agent_id}:{session_id}:{turn_id}",
+            value=str(num_infer_iters),
+        )
+
+    async def get_num_infer_iters_in_turn(self, session_id: str, turn_id: str) -> Optional[int]:
+        value = await self.kvstore.get(
+            key=f"num_infer_iters_in_turn:{self.agent_id}:{session_id}:{turn_id}",
+        )
+        return int(value) if value else None

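The new persistence helpers above store a per-turn inference-iteration counter in the kvstore as a string keyed by agent, session, and turn. A rough sketch of that round-trip with a plain dict standing in for the kvstore (the ids are hypothetical):

store = {}  # stand-in for the async kvstore
agent_id, session_id, turn_id = "agent-1", "sess-1", "turn-1"  # hypothetical ids

key = f"num_infer_iters_in_turn:{agent_id}:{session_id}:{turn_id}"
store[key] = str(3)                   # values are persisted as strings
value = store.get(key)
print(int(value) if value else None)  # 3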
@@ -3,6 +3,7 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
+import json
from typing import Any, Dict, List, Optional

from tqdm import tqdm
@@ -82,23 +83,22 @@ class MetaReferenceEvalImpl(
    async def run_eval(
        self,
        benchmark_id: str,
-        task_config: BenchmarkConfig,
+        benchmark_config: BenchmarkConfig,
    ) -> Job:
        task_def = self.benchmarks[benchmark_id]
        dataset_id = task_def.dataset_id
-        candidate = task_config.eval_candidate
        scoring_functions = task_def.scoring_functions
        dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
        validate_dataset_schema(dataset_def.dataset_schema, get_valid_schemas(Api.eval.value))
        all_rows = await self.datasetio_api.get_rows_paginated(
            dataset_id=dataset_id,
-            rows_in_page=(-1 if task_config.num_examples is None else task_config.num_examples),
+            rows_in_page=(-1 if benchmark_config.num_examples is None else benchmark_config.num_examples),
        )
        res = await self.evaluate_rows(
            benchmark_id=benchmark_id,
            input_rows=all_rows.rows,
            scoring_functions=scoring_functions,
-            task_config=task_config,
+            benchmark_config=benchmark_config,
        )

        # TODO: currently needs to wait for generation before returning
@@ -108,16 +108,16 @@ class MetaReferenceEvalImpl(
        return Job(job_id=job_id)

    async def _run_agent_generation(
-        self, input_rows: List[Dict[str, Any]], task_config: BenchmarkConfig
+        self, input_rows: List[Dict[str, Any]], benchmark_config: BenchmarkConfig
    ) -> List[Dict[str, Any]]:
-        candidate = task_config.eval_candidate
+        candidate = benchmark_config.eval_candidate
        create_response = await self.agents_api.create_agent(candidate.config)
        agent_id = create_response.agent_id

        generations = []
        for i, x in tqdm(enumerate(input_rows)):
            assert ColumnName.chat_completion_input.value in x, "Invalid input row"
-            input_messages = eval(str(x[ColumnName.chat_completion_input.value]))
+            input_messages = json.loads(x[ColumnName.chat_completion_input.value])
            input_messages = [UserMessage(**x) for x in input_messages]

            # NOTE: only single-turn agent generation is supported. Create a new session for each input row
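The eval provider above now parses dataset columns with json.loads instead of eval, which assumes the chat_completion_input column holds JSON and avoids executing arbitrary Python from the dataset. A small standalone illustration (the row content is made up):

import json

row = {"chat_completion_input": '[{"role": "user", "content": "What is 2+2?"}]'}
messages = json.loads(row["chat_completion_input"])  # plain data, nothing is executed
print(messages[0]["content"])  # What is 2+2?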
@@ -151,15 +151,15 @@ class MetaReferenceEvalImpl(
        return generations

    async def _run_model_generation(
-        self, input_rows: List[Dict[str, Any]], task_config: BenchmarkConfig
+        self, input_rows: List[Dict[str, Any]], benchmark_config: BenchmarkConfig
    ) -> List[Dict[str, Any]]:
-        candidate = task_config.eval_candidate
+        candidate = benchmark_config.eval_candidate
        assert candidate.sampling_params.max_tokens is not None, "SamplingParams.max_tokens must be provided"

        generations = []
        for x in tqdm(input_rows):
            if ColumnName.completion_input.value in x:
-                input_content = eval(str(x[ColumnName.completion_input.value]))
+                input_content = json.loads(x[ColumnName.completion_input.value])
                response = await self.inference_api.completion(
                    model=candidate.model,
                    content=input_content,
@@ -167,9 +167,8 @@ class MetaReferenceEvalImpl(
                )
                generations.append({ColumnName.generated_answer.value: response.completion_message.content})
            elif ColumnName.chat_completion_input.value in x:
-                chat_completion_input_str = str(x[ColumnName.chat_completion_input.value])
-                input_messages = eval(chat_completion_input_str)
-                input_messages = [UserMessage(**x) for x in input_messages]
+                chat_completion_input_json = json.loads(x[ColumnName.chat_completion_input.value])
+                input_messages = [UserMessage(**x) for x in chat_completion_input_json]
                messages = []
                if candidate.system_message:
                    messages.append(candidate.system_message)
@@ -190,13 +189,13 @@ class MetaReferenceEvalImpl(
        benchmark_id: str,
        input_rows: List[Dict[str, Any]],
        scoring_functions: List[str],
-        task_config: BenchmarkConfig,
+        benchmark_config: BenchmarkConfig,
    ) -> EvaluateResponse:
-        candidate = task_config.eval_candidate
+        candidate = benchmark_config.eval_candidate
        if candidate.type == "agent":
-            generations = await self._run_agent_generation(input_rows, task_config)
+            generations = await self._run_agent_generation(input_rows, benchmark_config)
        elif candidate.type == "model":
-            generations = await self._run_model_generation(input_rows, task_config)
+            generations = await self._run_model_generation(input_rows, benchmark_config)
        else:
            raise ValueError(f"Invalid candidate type: {candidate.type}")
@@ -205,9 +204,9 @@ class MetaReferenceEvalImpl(
            input_r | generated_r for input_r, generated_r in zip(input_rows, generations, strict=False)
        ]

-        if task_config.scoring_params is not None:
+        if benchmark_config.scoring_params is not None:
            scoring_functions_dict = {
-                scoring_fn_id: task_config.scoring_params.get(scoring_fn_id, None)
+                scoring_fn_id: benchmark_config.scoring_params.get(scoring_fn_id, None)
                for scoring_fn_id in scoring_functions
            }
        else:

@@ -0,0 +1,33 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from pathlib import Path
+from typing import List, Optional
+
+from pydantic import BaseModel
+
+from llama_stack.distribution.utils.model_utils import model_local_dir
+
+
+class TokenResult(BaseModel):
+    token: int
+    text: str
+    logprobs: Optional[List[float]] = None
+
+
+def model_checkpoint_dir(model_id) -> str:
+    checkpoint_dir = Path(model_local_dir(model_id))
+
+    paths = [Path(checkpoint_dir / f"consolidated.{ext}") for ext in ["pth", "00.pth"]]
+    if not any(p.exists() for p in paths):
+        checkpoint_dir = checkpoint_dir / "original"
+
+    assert checkpoint_dir.exists(), (
+        f"Could not find checkpoints in: {model_local_dir(model_id)}. "
+        f"If you try to use the native llama model, Please download model using `llama download --model-id {model_id}`"
+        f"Otherwise, please save you model checkpoint under {model_local_dir(model_id)}"
+    )
+    return str(checkpoint_dir)

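model_checkpoint_dir in the new file above falls back to an original/ subdirectory when no consolidated.pth or consolidated.00.pth is found at the top level. A standalone sketch of just that lookup logic (the example path is hypothetical):

from pathlib import Path


def find_checkpoint_dir(root: str) -> Path:
    checkpoint_dir = Path(root)
    candidates = [checkpoint_dir / f"consolidated.{ext}" for ext in ["pth", "00.pth"]]
    if not any(p.exists() for p in candidates):
        # fall back to an "original/" subdirectory, mirroring the helper above
        checkpoint_dir = checkpoint_dir / "original"
    return checkpoint_dir


print(find_checkpoint_dir("/tmp/checkpoints/Llama3.1-8B-Instruct"))
# prints ".../original" whenever no consolidated file exists at the top level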
@@ -55,7 +55,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
)

from .config import MetaReferenceInferenceConfig
-from .generation import Llama
+from .llama3.generation import Llama3
from .model_parallel import LlamaModelParallelGenerator

log = logging.getLogger(__name__)
@@ -83,7 +83,7 @@ class MetaReferenceInferenceImpl(
            self.generator = LlamaModelParallelGenerator(self.config, model_id, llama_model)
            self.generator.start()
        else:
-            self.generator = Llama.build(self.config, model_id, llama_model)
+            self.generator = Llama3.build(self.config, model_id, llama_model)

        self.model_id = model_id
        self.llama_model = llama_model
@@ -111,7 +111,7 @@ class MetaReferenceInferenceImpl(
            )
        if llama_model is None:
            raise ValueError(
-                "Please make sure your llama_model in model metadata or model identifier is in llama-models SKU list"
+                "Please make sure your llama_model in model metadata or model identifier is in Llama SKU list"
            )

        self.model_registry_helper = ModelRegistryHelper(
@@ -208,7 +208,6 @@ class MetaReferenceInferenceImpl(
        logprobs = []
        stop_reason = None

-        tokenizer = self.generator.formatter.tokenizer
        for token_result in self.generator.completion(request):
            tokens.append(token_result.token)
            if token_result.text == "<|eot_id|>":

Some files were not shown because too many files have changed in this diff.