Merge remote-tracking branch 'origin/main' into support_more_data_format

This commit is contained in:
Botao Chen 2025-01-13 20:36:14 -08:00
commit a3b1c3438b
171 changed files with 14529 additions and 5612 deletions

2
.github/CODEOWNERS vendored
View file

@ -2,4 +2,4 @@
# These owners will be the default owners for everything in
# the repo. Unless a later match takes precedence,
* @ashwinb @yanxi0830 @hardikjshah @dltn @raghotham @dineshyv @vladimirivic
* @ashwinb @yanxi0830 @hardikjshah @dltn @raghotham @dineshyv @vladimirivic @sixianyi0721

View file

@ -0,0 +1,232 @@
name: Publish Python 🐍 distribution 📦 to TestPyPI
on:
workflow_dispatch: # Keep manual trigger
inputs:
version:
description: 'Version number (e.g. 0.0.63.dev20250111)'
required: true
type: string
schedule:
- cron: "0 0 * * *" # Run every day at midnight
jobs:
trigger-client-and-models-build:
name: Trigger llama-stack-client and llama-models build
runs-on: ubuntu-latest
outputs:
version: ${{ steps.version.outputs.version }}
client_run_id: ${{ steps.trigger-client.outputs.workflow_id }}
model_run_id: ${{ steps.trigger-models.outputs.workflow_id }}
steps:
- uses: actions/checkout@v4
with:
persist-credentials: false
- name: Get date
id: date
run: echo "date=$(date +'%Y%m%d')" >> $GITHUB_OUTPUT
- name: Compute version based on dispatch event
id: version
run: |
# Read base version from pyproject.toml
version=$(sed -n 's/.*version="\([^"]*\)".*/\1/p' setup.py)
if [ "${{ github.event_name }}" = "schedule" ]; then
echo "version=${version}.dev${{ steps.date.outputs.date }}" >> $GITHUB_OUTPUT
elif [ "${{ github.event_name }}" = "workflow_dispatch" ]; then
echo "version=${{ inputs.version }}" >> $GITHUB_OUTPUT
else
echo "version=${version}.dev$(shuf -i 10000000-99999999 -n 1)" >> $GITHUB_OUTPUT
fi
- name: Trigger llama-stack-client workflow
id: trigger-client
run: |
response=$(curl -X POST https://api.github.com/repos/meta-llama/llama-stack-client-python/dispatches \
-H 'Accept: application/vnd.github.everest-preview+json' \
-H "authorization: Bearer ${{ secrets.PAT_TOKEN }}" \
--data "{\"event_type\": \"build-client-package\", \"client_payload\": {\"source\": \"llama-stack-nightly\", \"version\": \"${{ steps.version.outputs.version }}\"}}" \
-w "\n%{http_code}")
http_code=$(echo "$response" | tail -n1)
if [ "$http_code" != "204" ]; then
echo "Failed to trigger client workflow"
exit 1
fi
# Get the run ID of the triggered workflow
sleep 5 # Wait for workflow to be created
run_id=$(curl -s -H "authorization: Bearer ${{ secrets.PAT_TOKEN }}" \
"https://api.github.com/repos/meta-llama/llama-stack-client-python/actions/runs?event=repository_dispatch" \
| jq '.workflow_runs[0].id')
echo "workflow_id=$run_id" >> $GITHUB_OUTPUT
- name: Trigger llama-models workflow
id: trigger-models
run: |
response=$(curl -X POST https://api.github.com/repos/meta-llama/llama-models/dispatches \
-H 'Accept: application/vnd.github.everest-preview+json' \
-H "authorization: Bearer ${{ secrets.PAT_TOKEN }}" \
--data "{\"event_type\": \"build-models-package\", \"client_payload\": {\"source\": \"llama-stack-nightly\", \"version\": \"${{ steps.version.outputs.version }}\"}}" \
-w "\n%{http_code}")
http_code=$(echo "$response" | tail -n1)
if [ "$http_code" != "204" ]; then
echo "Failed to trigger models workflow"
exit 1
fi
# Get the run ID of the triggered workflow
sleep 5 # Wait for workflow to be created
run_id=$(curl -s -H "authorization: Bearer ${{ secrets.PAT_TOKEN }}" \
"https://api.github.com/repos/meta-llama/llama-models/actions/runs?event=repository_dispatch" \
| jq '.workflow_runs[0].id')
echo "workflow_id=$run_id" >> $GITHUB_OUTPUT
wait-for-workflows:
name: Wait for triggered workflows
needs: trigger-client-and-models-build
runs-on: ubuntu-latest
steps:
- name: Wait for client workflow
run: |
while true; do
status=$(curl -s -H "authorization: Bearer ${{ secrets.PAT_TOKEN }}" \
"https://api.github.com/repos/meta-llama/llama-stack-client-python/actions/runs/${{ needs.trigger-client-and-models-build.outputs.client_run_id }}" \
| jq -r '.status')
conclusion=$(curl -s -H "authorization: Bearer ${{ secrets.PAT_TOKEN }}" \
"https://api.github.com/repos/meta-llama/llama-stack-client-python/actions/runs/${{ needs.trigger-client-and-models-build.outputs.client_run_id }}" \
| jq -r '.conclusion')
echo "llama-stack-client-python workflow status: $status, conclusion: $conclusion"
if [ "$status" = "completed" ]; then
if [ "$conclusion" != "success" ]; then
echo "llama-stack-client-python workflow failed"
exit 1
fi
break
fi
sleep 10
done
- name: Wait for models workflow
run: |
while true; do
status=$(curl -s -H "authorization: Bearer ${{ secrets.PAT_TOKEN }}" \
"https://api.github.com/repos/meta-llama/llama-models/actions/runs/${{ needs.trigger-client-and-models-build.outputs.model_run_id }}" \
| jq -r '.status')
conclusion=$(curl -s -H "authorization: Bearer ${{ secrets.PAT_TOKEN }}" \
"https://api.github.com/repos/meta-llama/llama-models/actions/runs/${{ needs.trigger-client-and-models-build.outputs.model_run_id }}" \
| jq -r '.conclusion')
echo "llama-models workflow status: $status, conclusion: $conclusion"
if [ "$status" = "completed" ]; then
if [ "$conclusion" != "success" ]; then
echo "llama-models workflow failed"
exit 1
fi
break
fi
sleep 10
done
build:
name: Build distribution 📦
needs:
- wait-for-workflows
- trigger-client-and-models-build
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
with:
persist-credentials: false
- name: Get date
id: date
run: echo "date=$(date +'%Y%m%d')" >> $GITHUB_OUTPUT
- name: Update version for nightly
run: |
sed -i 's/version="\([^"]*\)"/version="${{ needs.trigger-client-and-models-build.outputs.version }}"/' setup.py
sed -i 's/llama-stack-client>=\([^"]*\)/llama-stack-client==${{ needs.trigger-client-and-models-build.outputs.version }}/' requirements.txt
sed -i 's/llama-models>=\([^"]*\)/llama-models==${{ needs.trigger-client-and-models-build.outputs.version }}/' requirements.txt
- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: "3.11"
- name: Install pypa/build
run: >-
python3 -m
pip install
build
--user
- name: Build a binary wheel and a source tarball
run: python3 -m build
- name: Store the distribution packages
uses: actions/upload-artifact@v4
with:
name: python-package-distributions
path: dist/
publish-to-testpypi:
name: Publish Python 🐍 distribution 📦 to TestPyPI
needs:
- build
runs-on: ubuntu-latest
environment:
name: testrelease
url: https://test.pypi.org/p/llama-stack
permissions:
id-token: write # IMPORTANT: mandatory for trusted publishing
steps:
- name: Download all the dists
uses: actions/download-artifact@v4
with:
name: python-package-distributions
path: dist/
- name: Publish distribution 📦 to TestPyPI
uses: pypa/gh-action-pypi-publish@release/v1
with:
repository-url: https://test.pypi.org/legacy/
test-published-package:
name: Test published package
needs:
- publish-to-testpypi
- trigger-client-and-models-build
runs-on: ubuntu-latest
steps:
- name: Install the package
run: |
max_attempts=6
attempt=1
while [ $attempt -le $max_attempts ]; do
echo "Attempt $attempt of $max_attempts to install package..."
if pip install --no-cache --index-url https://pypi.org/simple/ --extra-index-url https://test.pypi.org/simple/ llama-stack==${{ needs.trigger-client-and-models-build.outputs.version }}; then
echo "Package installed successfully"
break
fi
if [ $attempt -ge $max_attempts ]; then
echo "Failed to install package after $max_attempts attempts"
exit 1
fi
attempt=$((attempt + 1))
sleep 10
done
- name: Test the package versions
run: |
pip list | grep llama_
- name: Test CLI commands
run: |
llama model list
llama stack build --list-templates
llama model prompt-format -m Llama3.2-11B-Vision-Instruct
llama stack list-apis
llama stack list-providers inference
llama stack list-providers telemetry
# TODO: add trigger for integration test workflow & docker builds

View file

@ -99,7 +99,7 @@ Additionally, we have designed every element of the Stack such that APIs as well
|:---------------------------------------------:|:-------------------------------------------------------------------------------------------------------------------------------------------------------------:|:------------------------------------------------------------------------------------------------------------------------:|
| Meta Reference | [llamastack/distribution-meta-reference-gpu](https://hub.docker.com/repository/docker/llamastack/distribution-meta-reference-gpu/general) | [Guide](https://llama-stack.readthedocs.io/en/latest/distributions/self_hosted_distro/meta-reference-gpu.html) |
| Meta Reference Quantized | [llamastack/distribution-meta-reference-quantized-gpu](https://hub.docker.com/repository/docker/llamastack/distribution-meta-reference-quantized-gpu/general) | [Guide](https://llama-stack.readthedocs.io/en/latest/distributions/self_hosted_distro/meta-reference-quantized-gpu.html) |
| Cerebras | [llamastack/distribution-cerebras](https://hub.docker.com/repository/docker/llamastack/distribution-cerebras/general) | [Guide](https://llama-stack.readthedocs.io/en/latest/getting_started/distributions/self_hosted_distro/cerebras.html) |
| Cerebras | [llamastack/distribution-cerebras](https://hub.docker.com/repository/docker/llamastack/distribution-cerebras/general) | [Guide](https://llama-stack.readthedocs.io/en/latest/distributions/self_hosted_distro/cerebras.html) |
| Ollama | [llamastack/distribution-ollama](https://hub.docker.com/repository/docker/llamastack/distribution-ollama/general) | [Guide](https://llama-stack.readthedocs.io/en/latest/distributions/self_hosted_distro/ollama.html) |
| TGI | [llamastack/distribution-tgi](https://hub.docker.com/repository/docker/llamastack/distribution-tgi/general) | [Guide](https://llama-stack.readthedocs.io/en/latest/distributions/self_hosted_distro/tgi.html) |
| Together | [llamastack/distribution-together](https://hub.docker.com/repository/docker/llamastack/distribution-together/general) | [Guide](https://llama-stack.readthedocs.io/en/latest/distributions/self_hosted_distro/together.html) |

View file

@ -1,9 +1,42 @@
{
"bedrock": [
"hf-serverless": [
"aiohttp",
"aiosqlite",
"autoevals",
"blobfile",
"chardet",
"chromadb-client",
"datasets",
"faiss-cpu",
"fastapi",
"fire",
"httpx",
"huggingface_hub",
"matplotlib",
"nltk",
"numpy",
"openai",
"opentelemetry-exporter-otlp-proto-http",
"opentelemetry-sdk",
"pandas",
"pillow",
"psycopg2-binary",
"pypdf",
"redis",
"requests",
"scikit-learn",
"scipy",
"sentencepiece",
"tqdm",
"transformers",
"uvicorn",
"sentence-transformers --no-deps",
"torch --index-url https://download.pytorch.org/whl/cpu"
],
"together": [
"aiosqlite",
"autoevals",
"blobfile",
"boto3",
"chardet",
"chromadb-client",
"datasets",
@ -22,6 +55,71 @@
"psycopg2-binary",
"pypdf",
"redis",
"requests",
"scikit-learn",
"scipy",
"sentencepiece",
"together",
"tqdm",
"transformers",
"uvicorn",
"sentence-transformers --no-deps",
"torch --index-url https://download.pytorch.org/whl/cpu"
],
"vllm-gpu": [
"aiosqlite",
"autoevals",
"blobfile",
"chardet",
"chromadb-client",
"datasets",
"faiss-cpu",
"fastapi",
"fire",
"httpx",
"matplotlib",
"nltk",
"numpy",
"openai",
"opentelemetry-exporter-otlp-proto-http",
"opentelemetry-sdk",
"pandas",
"pillow",
"psycopg2-binary",
"pypdf",
"redis",
"requests",
"scikit-learn",
"scipy",
"sentencepiece",
"tqdm",
"transformers",
"uvicorn",
"vllm",
"sentence-transformers --no-deps",
"torch --index-url https://download.pytorch.org/whl/cpu"
],
"remote-vllm": [
"aiosqlite",
"blobfile",
"chardet",
"chromadb-client",
"faiss-cpu",
"fastapi",
"fire",
"httpx",
"matplotlib",
"nltk",
"numpy",
"openai",
"opentelemetry-exporter-otlp-proto-http",
"opentelemetry-sdk",
"pandas",
"pillow",
"psycopg2-binary",
"pypdf",
"redis",
"requests",
"scikit-learn",
"scipy",
"sentencepiece",
@ -54,6 +152,7 @@
"psycopg2-binary",
"pypdf",
"redis",
"requests",
"scikit-learn",
"scipy",
"sentencepiece",
@ -63,7 +162,7 @@
"sentence-transformers --no-deps",
"torch --index-url https://download.pytorch.org/whl/cpu"
],
"hf-endpoint": [
"tgi": [
"aiohttp",
"aiosqlite",
"autoevals",
@ -87,6 +186,7 @@
"psycopg2-binary",
"pypdf",
"redis",
"requests",
"scikit-learn",
"scipy",
"sentencepiece",
@ -96,11 +196,11 @@
"sentence-transformers --no-deps",
"torch --index-url https://download.pytorch.org/whl/cpu"
],
"hf-serverless": [
"aiohttp",
"bedrock": [
"aiosqlite",
"autoevals",
"blobfile",
"boto3",
"chardet",
"chromadb-client",
"datasets",
@ -108,7 +208,6 @@
"fastapi",
"fire",
"httpx",
"huggingface_hub",
"matplotlib",
"nltk",
"numpy",
@ -120,6 +219,7 @@
"psycopg2-binary",
"pypdf",
"redis",
"requests",
"scikit-learn",
"scipy",
"sentencepiece",
@ -154,6 +254,7 @@
"psycopg2-binary",
"pypdf",
"redis",
"requests",
"scikit-learn",
"scipy",
"sentence-transformers",
@ -193,6 +294,7 @@
"psycopg2-binary",
"pypdf",
"redis",
"requests",
"scikit-learn",
"scipy",
"sentence-transformers",
@ -207,6 +309,35 @@
"sentence-transformers --no-deps",
"torch --index-url https://download.pytorch.org/whl/cpu"
],
"cerebras": [
"aiosqlite",
"blobfile",
"cerebras_cloud_sdk",
"chardet",
"faiss-cpu",
"fastapi",
"fire",
"httpx",
"matplotlib",
"nltk",
"numpy",
"opentelemetry-exporter-otlp-proto-http",
"opentelemetry-sdk",
"pandas",
"pillow",
"psycopg2-binary",
"pypdf",
"redis",
"requests",
"scikit-learn",
"scipy",
"sentencepiece",
"tqdm",
"transformers",
"uvicorn",
"sentence-transformers --no-deps",
"torch --index-url https://download.pytorch.org/whl/cpu"
],
"ollama": [
"aiohttp",
"aiosqlite",
@ -231,6 +362,7 @@
"psycopg2-binary",
"pypdf",
"redis",
"requests",
"scikit-learn",
"scipy",
"sentencepiece",
@ -240,7 +372,7 @@
"sentence-transformers --no-deps",
"torch --index-url https://download.pytorch.org/whl/cpu"
],
"tgi": [
"hf-endpoint": [
"aiohttp",
"aiosqlite",
"autoevals",
@ -264,127 +396,7 @@
"psycopg2-binary",
"pypdf",
"redis",
"scikit-learn",
"scipy",
"sentencepiece",
"tqdm",
"transformers",
"uvicorn",
"sentence-transformers --no-deps",
"torch --index-url https://download.pytorch.org/whl/cpu"
],
"together": [
"aiosqlite",
"autoevals",
"blobfile",
"chardet",
"chromadb-client",
"datasets",
"faiss-cpu",
"fastapi",
"fire",
"httpx",
"matplotlib",
"nltk",
"numpy",
"openai",
"opentelemetry-exporter-otlp-proto-http",
"opentelemetry-sdk",
"pandas",
"pillow",
"psycopg2-binary",
"pypdf",
"redis",
"scikit-learn",
"scipy",
"sentencepiece",
"together",
"tqdm",
"transformers",
"uvicorn",
"sentence-transformers --no-deps",
"torch --index-url https://download.pytorch.org/whl/cpu"
],
"remote-vllm": [
"aiosqlite",
"blobfile",
"chardet",
"chromadb-client",
"faiss-cpu",
"fastapi",
"fire",
"httpx",
"matplotlib",
"nltk",
"numpy",
"openai",
"opentelemetry-exporter-otlp-proto-http",
"opentelemetry-sdk",
"pandas",
"pillow",
"psycopg2-binary",
"pypdf",
"redis",
"scikit-learn",
"scipy",
"sentencepiece",
"tqdm",
"transformers",
"uvicorn",
"sentence-transformers --no-deps",
"torch --index-url https://download.pytorch.org/whl/cpu"
],
"vllm-gpu": [
"aiosqlite",
"autoevals",
"blobfile",
"chardet",
"chromadb-client",
"datasets",
"faiss-cpu",
"fastapi",
"fire",
"httpx",
"matplotlib",
"nltk",
"numpy",
"openai",
"opentelemetry-exporter-otlp-proto-http",
"opentelemetry-sdk",
"pandas",
"pillow",
"psycopg2-binary",
"pypdf",
"redis",
"scikit-learn",
"scipy",
"sentencepiece",
"tqdm",
"transformers",
"uvicorn",
"vllm",
"sentence-transformers --no-deps",
"torch --index-url https://download.pytorch.org/whl/cpu"
],
"cerebras": [
"aiosqlite",
"blobfile",
"cerebras_cloud_sdk",
"chardet",
"faiss-cpu",
"fastapi",
"fire",
"httpx",
"matplotlib",
"nltk",
"numpy",
"opentelemetry-exporter-otlp-proto-http",
"opentelemetry-sdk",
"pandas",
"pillow",
"psycopg2-binary",
"pypdf",
"redis",
"requests",
"scikit-learn",
"scipy",
"sentencepiece",

View file

@ -85,7 +85,7 @@ services:
- SQLITE_STORE_DIR=${SQLITE_STORE_DIR:-$HOME/.llama/distributions/remote-vllm}
- SAFETY_MODEL=${SAFETY_MODEL:-meta-llama/Llama-Guard-3-1B}
ports:
- "${LLAMASTACK_PORT:-5001}:${LLAMASTACK_PORT:-5001}"
- "${LLAMA_STACK_PORT:-5001}:${LLAMA_STACK_PORT:-5001}"
# Hack: wait for vLLM server to start before starting docker
entrypoint: bash -c "sleep 60; python -m llama_stack.distribution.server.server --yaml_config /root/llamastack-run-remote-vllm.yaml --port 5001"
deploy:

File diff suppressed because one or more lines are too long

View file

@ -486,13 +486,22 @@ class Generator:
parameters = path_parameters + query_parameters
parameters += [
Parameter(
name="X-LlamaStack-ProviderData",
name="X-LlamaStack-Provider-Data",
in_=ParameterLocation.Header,
description="JSON-encoded provider data which will be made available to the adapter servicing the API",
required=False,
schema=self.schema_builder.classdef_to_ref(str),
)
]
parameters += [
Parameter(
name="X-LlamaStack-Client-Version",
in_=ParameterLocation.Header,
description="Version of the client making the request. This is used to ensure that the client and server are compatible.",
required=False,
schema=self.schema_builder.classdef_to_ref(str),
)
]
# data passed in payload
if op.request_params:

File diff suppressed because it is too large Load diff

File diff suppressed because it is too large Load diff

View file

@ -8,10 +8,6 @@ building_distro
configuration
```
<!-- self_hosted_distro/index -->
<!-- remote_hosted_distro/index -->
<!-- ondevice_distro/index -->
You can instantiate a Llama Stack in one of the following ways:
- **As a Library**: this is the simplest, especially if you are using an external inference service. See [Using Llama Stack as a Library](importing_as_library)
- **Docker**: we provide a number of pre-built Docker containers so you can start a Llama Stack server instantly. You can also build your own custom Docker container.
@ -30,11 +26,15 @@ If so, we suggest:
- {dockerhub}`distribution-ollama` ([Guide](self_hosted_distro/ollama))
- **Do you have an API key for a remote inference provider like Fireworks, Together, etc.?** If so, we suggest:
- {dockerhub}`distribution-together` ([Guide](remote_hosted_distro/index))
- {dockerhub}`distribution-fireworks` ([Guide](remote_hosted_distro/index))
- {dockerhub}`distribution-together` ([Guide](self_hosted_distro/together))
- {dockerhub}`distribution-fireworks` ([Guide](self_hosted_distro/fireworks))
- **Do you want to run Llama Stack inference on your iOS / Android device** If so, we suggest:
- [iOS SDK](ondevice_distro/ios_sdk)
- [Android](ondevice_distro/android_sdk)
- **Do you want a hosted Llama Stack endpoint?** If so, we suggest:
- [Remote-Hosted Llama Stack Endpoints](remote_hosted_distro/index)
You can also build your own [custom distribution](building_distro).

View file

@ -19,6 +19,7 @@ The `llamastack/distribution-bedrock` distribution consists of the following pro
| safety | `remote::bedrock` |
| scoring | `inline::basic`, `inline::llm-as-judge`, `inline::braintrust` |
| telemetry | `inline::meta-reference` |
| tool_runtime | `remote::brave-search`, `remote::tavily-search`, `inline::code-interpreter`, `inline::memory-runtime` |
@ -26,7 +27,7 @@ The `llamastack/distribution-bedrock` distribution consists of the following pro
The following environment variables can be configured:
- `LLAMASTACK_PORT`: Port for the Llama Stack distribution server (default: `5001`)
- `LLAMA_STACK_PORT`: Port for the Llama Stack distribution server (default: `5001`)
### Models

View file

@ -1,5 +1,15 @@
---
orphan: true
---
# Cerebras Distribution
```{toctree}
:maxdepth: 2
:hidden:
self
```
The `llamastack/distribution-cerebras` distribution consists of the following provider configurations.
| API | Provider(s) |
@ -9,13 +19,14 @@ The `llamastack/distribution-cerebras` distribution consists of the following pr
| memory | `inline::meta-reference` |
| safety | `inline::llama-guard` |
| telemetry | `inline::meta-reference` |
| tool_runtime | `remote::brave-search`, `remote::tavily-search`, `inline::code-interpreter`, `inline::memory-runtime` |
### Environment Variables
The following environment variables can be configured:
- `LLAMASTACK_PORT`: Port for the Llama Stack distribution server (default: `5001`)
- `LLAMA_STACK_PORT`: Port for the Llama Stack distribution server (default: `5001`)
- `CEREBRAS_API_KEY`: Cerebras API Key (default: ``)
### Models

View file

@ -22,28 +22,30 @@ The `llamastack/distribution-fireworks` distribution consists of the following p
| safety | `inline::llama-guard` |
| scoring | `inline::basic`, `inline::llm-as-judge`, `inline::braintrust` |
| telemetry | `inline::meta-reference` |
| tool_runtime | `remote::brave-search`, `remote::tavily-search`, `inline::code-interpreter`, `inline::memory-runtime` |
### Environment Variables
The following environment variables can be configured:
- `LLAMASTACK_PORT`: Port for the Llama Stack distribution server (default: `5001`)
- `LLAMA_STACK_PORT`: Port for the Llama Stack distribution server (default: `5001`)
- `FIREWORKS_API_KEY`: Fireworks.AI API Key (default: ``)
### Models
The following models are available by default:
- `meta-llama/Llama-3.1-8B-Instruct (fireworks/llama-v3p1-8b-instruct)`
- `meta-llama/Llama-3.1-70B-Instruct (fireworks/llama-v3p1-70b-instruct)`
- `meta-llama/Llama-3.1-405B-Instruct-FP8 (fireworks/llama-v3p1-405b-instruct)`
- `meta-llama/Llama-3.2-1B-Instruct (fireworks/llama-v3p2-1b-instruct)`
- `meta-llama/Llama-3.2-3B-Instruct (fireworks/llama-v3p2-3b-instruct)`
- `meta-llama/Llama-3.2-11B-Vision-Instruct (fireworks/llama-v3p2-11b-vision-instruct)`
- `meta-llama/Llama-3.2-90B-Vision-Instruct (fireworks/llama-v3p2-90b-vision-instruct)`
- `meta-llama/Llama-Guard-3-8B (fireworks/llama-guard-3-8b)`
- `meta-llama/Llama-Guard-3-11B-Vision (fireworks/llama-guard-3-11b-vision)`
- `meta-llama/Llama-3.1-8B-Instruct (accounts/fireworks/models/llama-v3p1-8b-instruct)`
- `meta-llama/Llama-3.1-70B-Instruct (accounts/fireworks/models/llama-v3p1-70b-instruct)`
- `meta-llama/Llama-3.1-405B-Instruct-FP8 (accounts/fireworks/models/llama-v3p1-405b-instruct)`
- `meta-llama/Llama-3.2-1B-Instruct (accounts/fireworks/models/llama-v3p2-1b-instruct)`
- `meta-llama/Llama-3.2-3B-Instruct (accounts/fireworks/models/llama-v3p2-3b-instruct)`
- `meta-llama/Llama-3.2-11B-Vision-Instruct (accounts/fireworks/models/llama-v3p2-11b-vision-instruct)`
- `meta-llama/Llama-3.2-90B-Vision-Instruct (accounts/fireworks/models/llama-v3p2-90b-vision-instruct)`
- `meta-llama/Llama-3.3-70B-Instruct (accounts/fireworks/models/llama-v3p3-70b-instruct)`
- `meta-llama/Llama-Guard-3-8B (accounts/fireworks/models/llama-guard-3-8b)`
- `meta-llama/Llama-Guard-3-11B-Vision (accounts/fireworks/models/llama-guard-3-11b-vision)`
### Prerequisite: API Keys

View file

@ -22,6 +22,7 @@ The `llamastack/distribution-meta-reference-gpu` distribution consists of the fo
| safety | `inline::llama-guard` |
| scoring | `inline::basic`, `inline::llm-as-judge`, `inline::braintrust` |
| telemetry | `inline::meta-reference` |
| tool_runtime | `remote::brave-search`, `remote::tavily-search`, `inline::code-interpreter`, `inline::memory-runtime` |
Note that you need access to nvidia GPUs to run this distribution. This distribution is not compatible with CPU-only machines or machines with AMD GPUs.
@ -30,7 +31,7 @@ Note that you need access to nvidia GPUs to run this distribution. This distribu
The following environment variables can be configured:
- `LLAMASTACK_PORT`: Port for the Llama Stack distribution server (default: `5001`)
- `LLAMA_STACK_PORT`: Port for the Llama Stack distribution server (default: `5001`)
- `INFERENCE_MODEL`: Inference model loaded into the Meta Reference server (default: `meta-llama/Llama-3.2-3B-Instruct`)
- `INFERENCE_CHECKPOINT_DIR`: Directory containing the Meta Reference model checkpoint (default: `null`)
- `SAFETY_MODEL`: Name of the safety (Llama-Guard) model to use (default: `meta-llama/Llama-Guard-3-1B`)

View file

@ -22,6 +22,7 @@ The `llamastack/distribution-meta-reference-quantized-gpu` distribution consists
| safety | `inline::llama-guard` |
| scoring | `inline::basic`, `inline::llm-as-judge`, `inline::braintrust` |
| telemetry | `inline::meta-reference` |
| tool_runtime | `remote::brave-search`, `remote::tavily-search`, `inline::code-interpreter`, `inline::memory-runtime` |
The only difference vs. the `meta-reference-gpu` distribution is that it has support for more efficient inference -- with fp8, int4 quantization, etc.
@ -32,7 +33,7 @@ Note that you need access to nvidia GPUs to run this distribution. This distribu
The following environment variables can be configured:
- `LLAMASTACK_PORT`: Port for the Llama Stack distribution server (default: `5001`)
- `LLAMA_STACK_PORT`: Port for the Llama Stack distribution server (default: `5001`)
- `INFERENCE_MODEL`: Inference model loaded into the Meta Reference server (default: `meta-llama/Llama-3.2-3B-Instruct`)
- `INFERENCE_CHECKPOINT_DIR`: Directory containing the Meta Reference model checkpoint (default: `null`)

View file

@ -22,13 +22,14 @@ The `llamastack/distribution-ollama` distribution consists of the following prov
| safety | `inline::llama-guard` |
| scoring | `inline::basic`, `inline::llm-as-judge`, `inline::braintrust` |
| telemetry | `inline::meta-reference` |
| tool_runtime | `remote::brave-search`, `remote::tavily-search`, `inline::code-interpreter`, `inline::memory-runtime` |
You should use this distribution if you have a regular desktop machine without very powerful GPUs. Of course, if you have powerful GPUs, you can still continue using this distribution since Ollama supports GPU acceleration.### Environment Variables
The following environment variables can be configured:
- `LLAMASTACK_PORT`: Port for the Llama Stack distribution server (default: `5001`)
- `LLAMA_STACK_PORT`: Port for the Llama Stack distribution server (default: `5001`)
- `OLLAMA_URL`: URL of the Ollama server (default: `http://127.0.0.1:11434`)
- `INFERENCE_MODEL`: Inference model loaded into the Ollama server (default: `meta-llama/Llama-3.2-3B-Instruct`)
- `SAFETY_MODEL`: Safety model loaded into the Ollama server (default: `meta-llama/Llama-Guard-3-1B`)

View file

@ -18,6 +18,7 @@ The `llamastack/distribution-remote-vllm` distribution consists of the following
| memory | `inline::faiss`, `remote::chromadb`, `remote::pgvector` |
| safety | `inline::llama-guard` |
| telemetry | `inline::meta-reference` |
| tool_runtime | `remote::brave-search`, `remote::tavily-search`, `inline::code-interpreter`, `inline::memory-runtime` |
You can use this distribution if you have GPUs and want to run an independent vLLM server container for running inference.
@ -26,9 +27,9 @@ You can use this distribution if you have GPUs and want to run an independent vL
The following environment variables can be configured:
- `LLAMASTACK_PORT`: Port for the Llama Stack distribution server (default: `5001`)
- `LLAMA_STACK_PORT`: Port for the Llama Stack distribution server (default: `5001`)
- `INFERENCE_MODEL`: Inference model loaded into the vLLM server (default: `meta-llama/Llama-3.2-3B-Instruct`)
- `VLLM_URL`: URL of the vLLM server with the main inference model (default: `http://host.docker.internal:5100}/v1`)
- `VLLM_URL`: URL of the vLLM server with the main inference model (default: `http://host.docker.internal:5100/v1`)
- `MAX_TOKENS`: Maximum number of tokens for generation (default: `4096`)
- `SAFETY_VLLM_URL`: URL of the vLLM server with the safety model (default: `http://host.docker.internal:5101/v1`)
- `SAFETY_MODEL`: Name of the safety (Llama-Guard) model to use (default: `meta-llama/Llama-Guard-3-1B`)

View file

@ -23,6 +23,7 @@ The `llamastack/distribution-tgi` distribution consists of the following provide
| safety | `inline::llama-guard` |
| scoring | `inline::basic`, `inline::llm-as-judge`, `inline::braintrust` |
| telemetry | `inline::meta-reference` |
| tool_runtime | `remote::brave-search`, `remote::tavily-search`, `inline::code-interpreter`, `inline::memory-runtime` |
You can use this distribution if you have GPUs and want to run an independent TGI server container for running inference.
@ -31,7 +32,7 @@ You can use this distribution if you have GPUs and want to run an independent TG
The following environment variables can be configured:
- `LLAMASTACK_PORT`: Port for the Llama Stack distribution server (default: `5001`)
- `LLAMA_STACK_PORT`: Port for the Llama Stack distribution server (default: `5001`)
- `INFERENCE_MODEL`: Inference model loaded into the TGI server (default: `meta-llama/Llama-3.2-3B-Instruct`)
- `TGI_URL`: URL of the TGI server with the main inference model (default: `http://127.0.0.1:8080}/v1`)
- `TGI_SAFETY_URL`: URL of the TGI server with the safety model (default: `http://127.0.0.1:8081/v1`)

View file

@ -22,13 +22,14 @@ The `llamastack/distribution-together` distribution consists of the following pr
| safety | `inline::llama-guard` |
| scoring | `inline::basic`, `inline::llm-as-judge`, `inline::braintrust` |
| telemetry | `inline::meta-reference` |
| tool_runtime | `remote::brave-search`, `remote::tavily-search`, `inline::code-interpreter`, `inline::memory-runtime` |
### Environment Variables
The following environment variables can be configured:
- `LLAMASTACK_PORT`: Port for the Llama Stack distribution server (default: `5001`)
- `LLAMA_STACK_PORT`: Port for the Llama Stack distribution server (default: `5001`)
- `TOGETHER_API_KEY`: Together.AI API Key (default: ``)
### Models
@ -41,6 +42,7 @@ The following models are available by default:
- `meta-llama/Llama-3.2-3B-Instruct`
- `meta-llama/Llama-3.2-11B-Vision-Instruct`
- `meta-llama/Llama-3.2-90B-Vision-Instruct`
- `meta-llama/Llama-3.3-70B-Instruct`
- `meta-llama/Llama-Guard-3-8B`
- `meta-llama/Llama-Guard-3-11B-Vision`

View file

@ -97,20 +97,20 @@ To download models, you can use the llama download command.
#### Downloading from [Meta](https://llama.meta.com/llama-downloads/)
Here is an example download command to get the 3B-Instruct/11B-Vision-Instruct model. You will need META_URL which can be obtained from [here](https://llama.meta.com/docs/getting_the_models/meta/)
Here is an example download command to get the 3B-Instruct/11B-Vision-Instruct model. You will need META_URL which can be obtained from [here](https://llama.meta.com/docs/getting_the_models/meta/). Note: You need to quote the META_URL
Download the required checkpoints using the following commands:
```bash
# download the 8B model, this can be run on a single GPU
llama download --source meta --model-id Llama3.2-3B-Instruct --meta-url META_URL
llama download --source meta --model-id Llama3.2-3B-Instruct --meta-url 'META_URL'
# you can also get the 70B model, this will require 8 GPUs however
llama download --source meta --model-id Llama3.2-11B-Vision-Instruct --meta-url META_URL
llama download --source meta --model-id Llama3.2-11B-Vision-Instruct --meta-url 'META_URL'
# llama-agents have safety enabled by default. For this, you will need
# safety models -- Llama-Guard and Prompt-Guard
llama download --source meta --model-id Prompt-Guard-86M --meta-url META_URL
llama download --source meta --model-id Llama-Guard-3-1B --meta-url META_URL
llama download --source meta --model-id Prompt-Guard-86M --meta-url 'META_URL'
llama download --source meta --model-id Llama-Guard-3-1B --meta-url 'META_URL'
```
#### Downloading from [Hugging Face](https://huggingface.co/meta-llama)

View file

@ -89,7 +89,7 @@ If you're looking for more specific topics, we have a [Zero to Hero Guide](#next
```
...
Build Successful! Next steps:
1. Set the environment variables: LLAMASTACK_PORT, OLLAMA_URL, INFERENCE_MODEL, SAFETY_MODEL
1. Set the environment variables: LLAMA_STACK_PORT, OLLAMA_URL, INFERENCE_MODEL, SAFETY_MODEL
2. `llama stack run /Users/<username>/.llama/distributions/llamastack-ollama/ollama-run.yaml
```

View file

@ -18,15 +18,11 @@ from typing import (
Union,
)
from llama_models.llama3.api.datatypes import ToolParamDefinition
from llama_models.schema_utils import json_schema_type, webmethod
from llama_models.schema_utils import json_schema_type, register_schema, webmethod
from pydantic import BaseModel, ConfigDict, Field
from typing_extensions import Annotated
from llama_stack.apis.common.content_types import InterleavedContent, URL
from llama_stack.apis.common.deployment_types import RestAPIExecutionConfig
from llama_stack.apis.inference import (
CompletionMessage,
SamplingParams,
@ -40,166 +36,18 @@ from llama_stack.apis.inference import (
)
from llama_stack.apis.memory import MemoryBank
from llama_stack.apis.safety import SafetyViolation
from llama_stack.apis.tools import ToolDef
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
@json_schema_type
class Attachment(BaseModel):
content: InterleavedContent | URL
mime_type: str
class AgentTool(Enum):
brave_search = "brave_search"
wolfram_alpha = "wolfram_alpha"
photogen = "photogen"
code_interpreter = "code_interpreter"
function_call = "function_call"
memory = "memory"
class ToolDefinitionCommon(BaseModel):
input_shields: Optional[List[str]] = Field(default_factory=list)
output_shields: Optional[List[str]] = Field(default_factory=list)
class SearchEngineType(Enum):
bing = "bing"
brave = "brave"
tavily = "tavily"
@json_schema_type
class SearchToolDefinition(ToolDefinitionCommon):
# NOTE: brave_search is just a placeholder since model always uses
# brave_search as tool call name
type: Literal[AgentTool.brave_search.value] = AgentTool.brave_search.value
api_key: str
engine: SearchEngineType = SearchEngineType.brave
remote_execution: Optional[RestAPIExecutionConfig] = None
@json_schema_type
class WolframAlphaToolDefinition(ToolDefinitionCommon):
type: Literal[AgentTool.wolfram_alpha.value] = AgentTool.wolfram_alpha.value
api_key: str
remote_execution: Optional[RestAPIExecutionConfig] = None
@json_schema_type
class PhotogenToolDefinition(ToolDefinitionCommon):
type: Literal[AgentTool.photogen.value] = AgentTool.photogen.value
remote_execution: Optional[RestAPIExecutionConfig] = None
@json_schema_type
class CodeInterpreterToolDefinition(ToolDefinitionCommon):
type: Literal[AgentTool.code_interpreter.value] = AgentTool.code_interpreter.value
enable_inline_code_execution: bool = True
remote_execution: Optional[RestAPIExecutionConfig] = None
@json_schema_type
class FunctionCallToolDefinition(ToolDefinitionCommon):
type: Literal[AgentTool.function_call.value] = AgentTool.function_call.value
function_name: str
description: str
parameters: Dict[str, ToolParamDefinition]
remote_execution: Optional[RestAPIExecutionConfig] = None
class _MemoryBankConfigCommon(BaseModel):
bank_id: str
class AgentVectorMemoryBankConfig(_MemoryBankConfigCommon):
type: Literal["vector"] = "vector"
class AgentKeyValueMemoryBankConfig(_MemoryBankConfigCommon):
type: Literal["keyvalue"] = "keyvalue"
keys: List[str] # what keys to focus on
class AgentKeywordMemoryBankConfig(_MemoryBankConfigCommon):
type: Literal["keyword"] = "keyword"
class AgentGraphMemoryBankConfig(_MemoryBankConfigCommon):
type: Literal["graph"] = "graph"
entities: List[str] # what entities to focus on
MemoryBankConfig = Annotated[
Union[
AgentVectorMemoryBankConfig,
AgentKeyValueMemoryBankConfig,
AgentKeywordMemoryBankConfig,
AgentGraphMemoryBankConfig,
],
Field(discriminator="type"),
]
class MemoryQueryGenerator(Enum):
default = "default"
llm = "llm"
custom = "custom"
class DefaultMemoryQueryGeneratorConfig(BaseModel):
type: Literal[MemoryQueryGenerator.default.value] = (
MemoryQueryGenerator.default.value
)
sep: str = " "
class LLMMemoryQueryGeneratorConfig(BaseModel):
type: Literal[MemoryQueryGenerator.llm.value] = MemoryQueryGenerator.llm.value
model: str
template: str
class CustomMemoryQueryGeneratorConfig(BaseModel):
type: Literal[MemoryQueryGenerator.custom.value] = MemoryQueryGenerator.custom.value
MemoryQueryGeneratorConfig = Annotated[
Union[
DefaultMemoryQueryGeneratorConfig,
LLMMemoryQueryGeneratorConfig,
CustomMemoryQueryGeneratorConfig,
],
Field(discriminator="type"),
]
@json_schema_type
class MemoryToolDefinition(ToolDefinitionCommon):
type: Literal[AgentTool.memory.value] = AgentTool.memory.value
memory_bank_configs: List[MemoryBankConfig] = Field(default_factory=list)
# This config defines how a query is generated using the messages
# for memory bank retrieval.
query_generator_config: MemoryQueryGeneratorConfig = Field(
default=DefaultMemoryQueryGeneratorConfig()
)
max_tokens_in_context: int = 4096
max_chunks: int = 10
AgentToolDefinition = Annotated[
Union[
SearchToolDefinition,
WolframAlphaToolDefinition,
PhotogenToolDefinition,
CodeInterpreterToolDefinition,
FunctionCallToolDefinition,
MemoryToolDefinition,
],
Field(discriminator="type"),
]
class Document(BaseModel):
content: InterleavedContent | URL
mime_type: str
class StepCommon(BaseModel):
@ -289,13 +137,27 @@ class Session(BaseModel):
memory_bank: Optional[MemoryBank] = None
class AgentToolGroupWithArgs(BaseModel):
name: str
args: Dict[str, Any]
AgentToolGroup = register_schema(
Union[
str,
AgentToolGroupWithArgs,
],
name="AgentTool",
)
class AgentConfigCommon(BaseModel):
sampling_params: Optional[SamplingParams] = SamplingParams()
input_shields: Optional[List[str]] = Field(default_factory=list)
output_shields: Optional[List[str]] = Field(default_factory=list)
tools: Optional[List[AgentToolDefinition]] = Field(default_factory=list)
toolgroups: Optional[List[AgentToolGroup]] = Field(default_factory=list)
client_tools: Optional[List[ToolDef]] = Field(default_factory=list)
tool_choice: Optional[ToolChoice] = Field(default=ToolChoice.auto)
tool_prompt_format: Optional[ToolPromptFormat] = Field(
default=ToolPromptFormat.json
@ -340,6 +202,7 @@ class AgentTurnResponseStepCompletePayload(BaseModel):
AgentTurnResponseEventType.step_complete.value
)
step_type: StepType
step_id: str
step_details: Step
@ -413,7 +276,9 @@ class AgentTurnCreateRequest(AgentConfigOverridablePerTurn):
ToolResponseMessage,
]
]
attachments: Optional[List[Attachment]] = None
documents: Optional[List[Document]] = None
toolgroups: Optional[List[AgentToolGroup]] = None
stream: Optional[bool] = False
@ -450,8 +315,9 @@ class Agents(Protocol):
ToolResponseMessage,
]
],
attachments: Optional[List[Attachment]] = None,
stream: Optional[bool] = False,
documents: Optional[List[Document]] = None,
toolgroups: Optional[List[AgentToolGroup]] = None,
) -> Union[Turn, AsyncIterator[AgentTurnResponseStreamChunk]]: ...
@webmethod(route="/agents/turn/get")

View file

@ -7,7 +7,6 @@
from typing import List, Optional, Protocol, runtime_checkable
from llama_models.schema_utils import json_schema_type, webmethod
from pydantic import BaseModel, Field
from llama_stack.apis.inference import (
@ -44,9 +43,7 @@ class BatchChatCompletionRequest(BaseModel):
# zero-shot tool definitions as input to the model
tools: Optional[List[ToolDefinition]] = Field(default_factory=list)
tool_choice: Optional[ToolChoice] = Field(default=ToolChoice.auto)
tool_prompt_format: Optional[ToolPromptFormat] = Field(
default=ToolPromptFormat.json
)
tool_prompt_format: Optional[ToolPromptFormat] = Field(default=None)
logprobs: Optional[LogProbConfig] = None
@ -75,6 +72,6 @@ class BatchInference(Protocol):
# zero-shot tool definitions as input to the model
tools: Optional[List[ToolDefinition]] = list,
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
tool_prompt_format: Optional[ToolPromptFormat] = None,
logprobs: Optional[LogProbConfig] = None,
) -> BatchChatCompletionResponse: ...

View file

@ -5,7 +5,6 @@
# the root directory of this source tree.
from enum import Enum
from typing import (
Any,
AsyncIterator,
@ -26,16 +25,12 @@ from llama_models.llama3.api.datatypes import (
ToolDefinition,
ToolPromptFormat,
)
from llama_models.schema_utils import json_schema_type, register_schema, webmethod
from pydantic import BaseModel, Field, field_validator
from typing_extensions import Annotated
from llama_stack.apis.common.content_types import InterleavedContent
from llama_stack.apis.models import Model
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
@ -87,7 +82,7 @@ class SystemMessage(BaseModel):
@json_schema_type
class ToolResponseMessage(BaseModel):
role: Literal["ipython"] = "ipython"
role: Literal["tool"] = "tool"
# it was nice to re-use the ToolResponse type, but having all messages
# have a `content` type makes things nicer too
call_id: str
@ -256,9 +251,7 @@ class ChatCompletionRequest(BaseModel):
# zero-shot tool definitions as input to the model
tools: Optional[List[ToolDefinition]] = Field(default_factory=list)
tool_choice: Optional[ToolChoice] = Field(default=ToolChoice.auto)
tool_prompt_format: Optional[ToolPromptFormat] = Field(
default=ToolPromptFormat.json
)
tool_prompt_format: Optional[ToolPromptFormat] = Field(default=None)
response_format: Optional[ResponseFormat] = None
stream: Optional[bool] = False
@ -289,9 +282,7 @@ class BatchChatCompletionRequest(BaseModel):
# zero-shot tool definitions as input to the model
tools: Optional[List[ToolDefinition]] = Field(default_factory=list)
tool_choice: Optional[ToolChoice] = Field(default=ToolChoice.auto)
tool_prompt_format: Optional[ToolPromptFormat] = Field(
default=ToolPromptFormat.json
)
tool_prompt_format: Optional[ToolPromptFormat] = Field(default=None)
logprobs: Optional[LogProbConfig] = None
@ -334,7 +325,7 @@ class Inference(Protocol):
# zero-shot tool definitions as input to the model
tools: Optional[List[ToolDefinition]] = None,
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
tool_prompt_format: Optional[ToolPromptFormat] = None,
response_format: Optional[ResponseFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,

View file

@ -29,6 +29,11 @@ class HealthInfo(BaseModel):
# TODO: add a provider level status
@json_schema_type
class VersionInfo(BaseModel):
version: str
@runtime_checkable
class Inspect(Protocol):
@webmethod(route="/providers/list", method="GET")
@ -39,3 +44,6 @@ class Inspect(Protocol):
@webmethod(route="/health", method="GET")
async def health(self) -> HealthInfo: ...
@webmethod(route="/version", method="GET")
async def version(self) -> VersionInfo: ...

View file

@ -4,10 +4,10 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Annotated, Any, Dict, List, Literal, Optional, Union
from enum import Enum
from typing import Any, Dict, List, Literal, Optional
from llama_models.llama3.api.datatypes import ToolPromptFormat
from llama_models.schema_utils import json_schema_type, register_schema, webmethod
from llama_models.schema_utils import json_schema_type, webmethod
from pydantic import BaseModel, Field
from typing_extensions import Protocol, runtime_checkable
@ -21,59 +21,48 @@ class ToolParameter(BaseModel):
name: str
parameter_type: str
description: str
required: bool = Field(default=True)
default: Optional[Any] = None
@json_schema_type
class ToolHost(Enum):
distribution = "distribution"
client = "client"
model_context_protocol = "model_context_protocol"
@json_schema_type
class Tool(Resource):
type: Literal[ResourceType.tool.value] = ResourceType.tool.value
tool_group: str
toolgroup_id: str
tool_host: ToolHost
description: str
parameters: List[ToolParameter]
provider_id: Optional[str] = None
metadata: Optional[Dict[str, Any]] = None
tool_prompt_format: Optional[ToolPromptFormat] = Field(
default=ToolPromptFormat.json
)
@json_schema_type
class ToolDef(BaseModel):
name: str
description: str
parameters: List[ToolParameter]
metadata: Dict[str, Any]
tool_prompt_format: Optional[ToolPromptFormat] = Field(
default=ToolPromptFormat.json
)
description: Optional[str] = None
parameters: Optional[List[ToolParameter]] = None
metadata: Optional[Dict[str, Any]] = None
@json_schema_type
class MCPToolGroupDef(BaseModel):
"""
A tool group that is defined by in a model context protocol server.
Refer to https://modelcontextprotocol.io/docs/concepts/tools for more information.
"""
type: Literal["model_context_protocol"] = "model_context_protocol"
endpoint: URL
class ToolGroupInput(BaseModel):
toolgroup_id: str
provider_id: str
args: Optional[Dict[str, Any]] = None
mcp_endpoint: Optional[URL] = None
@json_schema_type
class UserDefinedToolGroupDef(BaseModel):
type: Literal["user_defined"] = "user_defined"
tools: List[ToolDef]
ToolGroupDef = register_schema(
Annotated[
Union[MCPToolGroupDef, UserDefinedToolGroupDef], Field(discriminator="type")
],
name="ToolGroup",
)
class ToolGroup(Resource):
type: Literal[ResourceType.tool_group.value] = ResourceType.tool_group.value
mcp_endpoint: Optional[URL] = None
args: Optional[Dict[str, Any]] = None
@json_schema_type
@ -85,6 +74,7 @@ class ToolInvocationResult(BaseModel):
class ToolStore(Protocol):
def get_tool(self, tool_name: str) -> Tool: ...
def get_tool_group(self, tool_group_id: str) -> ToolGroup: ...
@runtime_checkable
@ -93,9 +83,10 @@ class ToolGroups(Protocol):
@webmethod(route="/toolgroups/register", method="POST")
async def register_tool_group(
self,
tool_group_id: str,
tool_group: ToolGroupDef,
provider_id: Optional[str] = None,
toolgroup_id: str,
provider_id: str,
mcp_endpoint: Optional[URL] = None,
args: Optional[Dict[str, Any]] = None,
) -> None:
"""Register a tool group"""
...
@ -103,7 +94,7 @@ class ToolGroups(Protocol):
@webmethod(route="/toolgroups/get", method="GET")
async def get_tool_group(
self,
tool_group_id: str,
toolgroup_id: str,
) -> ToolGroup: ...
@webmethod(route="/toolgroups/list", method="GET")
@ -130,8 +121,11 @@ class ToolGroups(Protocol):
class ToolRuntime(Protocol):
tool_store: ToolStore
@webmethod(route="/tool-runtime/discover", method="POST")
async def discover_tools(self, tool_group: ToolGroupDef) -> List[ToolDef]: ...
# TODO: This needs to be renamed once OPEN API generator name conflict issue is fixed.
@webmethod(route="/tool-runtime/list-tools", method="GET")
async def list_runtime_tools(
self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
) -> List[ToolDef]: ...
@webmethod(route="/tool-runtime/invoke", method="POST")
async def invoke_tool(

View file

@ -43,7 +43,7 @@ class ModelPromptFormat(Subcommand):
)
def _run_model_template_cmd(self, args: argparse.Namespace) -> None:
import pkg_resources
import importlib.resources
# Only Llama 3.1 and 3.2 are supported
supported_model_ids = [
@ -64,25 +64,26 @@ class ModelPromptFormat(Subcommand):
f"{model_id} is not a valid Model. Choose one from --\n {model_str}"
)
llama_3_1_file = pkg_resources.resource_filename(
"llama_models", "llama3_1/prompt_format.md"
llama_3_1_file = (
importlib.resources.files("llama_models") / "llama3_1/prompt_format.md"
)
llama_3_2_text_file = pkg_resources.resource_filename(
"llama_models", "llama3_2/text_prompt_format.md"
llama_3_2_text_file = (
importlib.resources.files("llama_models") / "llama3_2/text_prompt_format.md"
)
llama_3_2_vision_file = pkg_resources.resource_filename(
"llama_models", "llama3_2/vision_prompt_format.md"
llama_3_2_vision_file = (
importlib.resources.files("llama_models")
/ "llama3_2/vision_prompt_format.md"
)
if model_family(model_id) == ModelFamily.llama3_1:
with open(llama_3_1_file, "r") as f:
content = f.read()
with importlib.resources.as_file(llama_3_1_file) as f:
content = f.open("r").read()
elif model_family(model_id) == ModelFamily.llama3_2:
if is_multimodal(model_id):
with open(llama_3_2_vision_file, "r") as f:
content = f.read()
with importlib.resources.as_file(llama_3_2_vision_file) as f:
content = f.open("r").read()
else:
with open(llama_3_2_text_file, "r") as f:
content = f.read()
with importlib.resources.as_file(llama_3_2_text_file) as f:
content = f.open("r").read()
render_markdown_to_pager(content)

View file

@ -4,14 +4,15 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import argparse
import importlib.resources
import os
import shutil
from functools import lru_cache
from pathlib import Path
from typing import List, Optional
import pkg_resources
from llama_stack.cli.subcommand import Subcommand
from llama_stack.distribution.datatypes import (
@ -290,13 +291,12 @@ class StackBuild(Subcommand):
if template_name:
# copy run.yaml from template to build_dir instead of generating it again
template_path = pkg_resources.resource_filename(
"llama_stack", f"templates/{template_name}/run.yaml"
template_path = (
importlib.resources.files("llama_stack")
/ f"templates/{template_name}/run.yaml"
)
os.makedirs(build_dir, exist_ok=True)
run_config_file = build_dir / f"{build_config.name}-run.yaml"
shutil.copy(template_path, run_config_file)
with importlib.resources.as_file(template_path) as path:
shutil.copy(path, run_config_file)
# Find all ${env.VARIABLE} patterns
cprint("Build Successful!", color="green")
else:

View file

@ -5,6 +5,7 @@
# the root directory of this source tree.
import argparse
import os
from pathlib import Path
from llama_stack.cli.subcommand import Subcommand
@ -34,7 +35,7 @@ class StackRun(Subcommand):
"--port",
type=int,
help="Port to run the server on. Defaults to 5000",
default=5000,
default=int(os.getenv("LLAMA_STACK_PORT", 5000)),
)
self.parser.add_argument(
"--disable-ipv6",
@ -51,7 +52,8 @@ class StackRun(Subcommand):
)
def _run_stack_run_cmd(self, args: argparse.Namespace) -> None:
import pkg_resources
import importlib.resources
import yaml
from llama_stack.distribution.build import ImageType
@ -106,15 +108,15 @@ class StackRun(Subcommand):
config = parse_and_maybe_upgrade_config(config_dict)
if config.docker_image:
script = pkg_resources.resource_filename(
"llama_stack",
"distribution/start_container.sh",
script = (
importlib.resources.files("llama_stack")
/ "distribution/start_container.sh"
)
run_args = [script, config.docker_image]
else:
script = pkg_resources.resource_filename(
"llama_stack",
"distribution/start_conda_env.sh",
script = (
importlib.resources.files("llama_stack")
/ "distribution/start_conda_env.sh"
)
run_args = [
script,

View file

@ -5,6 +5,7 @@
# the root directory of this source tree.
import argparse
from importlib.metadata import version
from llama_stack.cli.subcommand import Subcommand
@ -24,6 +25,12 @@ class StackParser(Subcommand):
description="Operations for the Llama Stack / Distributions",
)
self.parser.add_argument(
"--version",
action="version",
version=f"{version('llama-stack')}",
)
subparsers = self.parser.add_subparsers(title="stack_subcommands")
# Add sub-commands

View file

@ -4,13 +4,13 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import importlib.resources
import logging
from enum import Enum
from pathlib import Path
from typing import Dict, List
import pkg_resources
from pydantic import BaseModel
from termcolor import cprint
@ -111,8 +111,8 @@ def build_image(build_config: BuildConfig, build_file_path: Path):
normal_deps += SERVER_DEPENDENCIES
if build_config.image_type == ImageType.docker.value:
script = pkg_resources.resource_filename(
"llama_stack", "distribution/build_container.sh"
script = (
importlib.resources.files("llama_stack") / "distribution/build_container.sh"
)
args = [
script,
@ -123,8 +123,8 @@ def build_image(build_config: BuildConfig, build_file_path: Path):
" ".join(normal_deps),
]
elif build_config.image_type == ImageType.conda.value:
script = pkg_resources.resource_filename(
"llama_stack", "distribution/build_conda_env.sh"
script = (
importlib.resources.files("llama_stack") / "distribution/build_conda_env.sh"
)
args = [
script,
@ -133,9 +133,7 @@ def build_image(build_config: BuildConfig, build_file_path: Path):
" ".join(normal_deps),
]
elif build_config.image_type == ImageType.venv.value:
script = pkg_resources.resource_filename(
"llama_stack", "distribution/build_venv.sh"
)
script = importlib.resources.files("llama_stack") / "distribution/build_venv.sh"
args = [
script,
build_config.name,

View file

@ -51,7 +51,19 @@ add_to_docker() {
fi
}
add_to_docker <<EOF
# Update and install UBI9 components if UBI9 base image is used
if [[ $docker_base == *"registry.access.redhat.com/ubi9"* ]]; then
add_to_docker <<EOF
FROM $docker_base
WORKDIR /app
RUN microdnf -y update && microdnf install -y iputils net-tools wget \
vim-minimal python3.11 python3.11-pip python3.11-wheel \
python3.11-setuptools && ln -s /bin/pip3.11 /bin/pip && ln -s /bin/python3.11 /bin/python && microdnf clean all
EOF
else
add_to_docker <<EOF
FROM $docker_base
WORKDIR /app
@ -64,6 +76,7 @@ RUN apt-get update && apt-get install -y \
&& rm -rf /var/lib/apt/lists/*
EOF
fi
# Add pip dependencies first since llama-stack is what will change most often
# so we can reuse layers.

View file

@ -20,7 +20,7 @@ from llama_stack.apis.safety import Safety
from llama_stack.apis.scoring import Scoring
from llama_stack.apis.scoring_functions import ScoringFn, ScoringFnInput
from llama_stack.apis.shields import Shield, ShieldInput
from llama_stack.apis.tools import Tool, ToolGroup, ToolRuntime
from llama_stack.apis.tools import Tool, ToolGroup, ToolGroupInput, ToolRuntime
from llama_stack.providers.datatypes import Api, ProviderSpec
from llama_stack.providers.utils.kvstore.config import KVStoreConfig
@ -161,6 +161,7 @@ a default SQLite store will be used.""",
datasets: List[DatasetInput] = Field(default_factory=list)
scoring_fns: List[ScoringFnInput] = Field(default_factory=list)
eval_tasks: List[EvalTaskInput] = Field(default_factory=list)
tool_groups: List[ToolGroupInput] = Field(default_factory=list)
class BuildConfig(BaseModel):

View file

@ -4,11 +4,18 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from importlib.metadata import version
from typing import Dict, List
from pydantic import BaseModel
from llama_stack.apis.inspect import HealthInfo, Inspect, ProviderInfo, RouteInfo
from llama_stack.apis.inspect import (
HealthInfo,
Inspect,
ProviderInfo,
RouteInfo,
VersionInfo,
)
from llama_stack.distribution.datatypes import StackRunConfig
from llama_stack.distribution.server.endpoints import get_all_api_endpoints
@ -65,3 +72,6 @@ class DistributionInspectImpl(Inspect):
async def health(self) -> HealthInfo:
return HealthInfo(status="OK")
async def version(self) -> VersionInfo:
return VersionInfo(version=version("llama-stack"))

View file

@ -33,6 +33,7 @@ from termcolor import cprint
from llama_stack.distribution.build import print_pip_install_help
from llama_stack.distribution.configure import parse_and_maybe_upgrade_config
from llama_stack.distribution.datatypes import Api
from llama_stack.distribution.request_headers import set_request_provider_data
from llama_stack.distribution.resolver import ProviderRegistry
from llama_stack.distribution.server.endpoints import get_all_api_endpoints
from llama_stack.distribution.stack import (
@ -67,6 +68,7 @@ def stream_across_asyncio_run_boundary(
async_gen_maker,
pool_executor: ThreadPoolExecutor,
path: Optional[str] = None,
provider_data: Optional[dict[str, Any]] = None,
) -> Generator[T, None, None]:
result_queue = queue.Queue()
stop_event = threading.Event()
@ -75,6 +77,10 @@ def stream_across_asyncio_run_boundary(
# make sure we make the generator in the event loop context
gen = await async_gen_maker()
await start_trace(path, {"__location__": "library_client"})
if provider_data:
set_request_provider_data(
{"X-LlamaStack-Provider-Data": json.dumps(provider_data)}
)
try:
async for item in await gen:
result_queue.put(item)
@ -174,6 +180,7 @@ class LlamaStackAsLibraryClient(LlamaStackClient):
config_path_or_template_name: str,
skip_logger_removal: bool = False,
custom_provider_registry: Optional[ProviderRegistry] = None,
provider_data: Optional[dict[str, Any]] = None,
):
super().__init__()
self.async_client = AsyncLlamaStackAsLibraryClient(
@ -181,6 +188,7 @@ class LlamaStackAsLibraryClient(LlamaStackClient):
)
self.pool_executor = ThreadPoolExecutor(max_workers=4)
self.skip_logger_removal = skip_logger_removal
self.provider_data = provider_data
def initialize(self):
if in_notebook():
@ -219,10 +227,15 @@ class LlamaStackAsLibraryClient(LlamaStackClient):
lambda: self.async_client.request(*args, **kwargs),
self.pool_executor,
path=path,
provider_data=self.provider_data,
)
else:
async def _traced_request():
if self.provider_data:
set_request_provider_data(
{"X-LlamaStack-Provider-Data": json.dumps(self.provider_data)}
)
await start_trace(path, {"__location__": "library_client"})
try:
return await self.async_client.request(*args, **kwargs)
@ -267,6 +280,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
self.config, self.custom_provider_registry
)
except ModuleNotFoundError as _e:
cprint(_e.msg, "red")
cprint(
"Using llama-stack as a library requires installing dependencies depending on the template (providers) you choose.\n",
"yellow",

View file

@ -40,8 +40,8 @@ class NeedsRequestProviderData:
def set_request_provider_data(headers: Dict[str, str]):
keys = [
"X-LlamaStack-ProviderData",
"x-llamastack-providerdata",
"X-LlamaStack-Provider-Data",
"x-llamastack-provider-data",
]
for key in keys:
val = headers.get(key, None)

View file

@ -5,9 +5,7 @@
# the root directory of this source tree.
import importlib
import inspect
import logging
from typing import Any, Dict, List, Set
from llama_stack.apis.agents import Agents
@ -28,7 +26,6 @@ from llama_stack.apis.shields import Shields
from llama_stack.apis.telemetry import Telemetry
from llama_stack.apis.tools import ToolGroups, ToolRuntime
from llama_stack.distribution.client import get_client_impl
from llama_stack.distribution.datatypes import (
AutoRoutedProviderSpec,
Provider,
@ -38,7 +35,6 @@ from llama_stack.distribution.datatypes import (
from llama_stack.distribution.distribution import builtin_automatically_routed_apis
from llama_stack.distribution.store import DistributionRegistry
from llama_stack.distribution.utils.dynamic import instantiate_class_type
from llama_stack.providers.datatypes import (
Api,
DatasetsProtocolPrivate,

View file

@ -6,7 +6,7 @@
from typing import Any, AsyncGenerator, Dict, List, Optional
from llama_stack.apis.common.content_types import InterleavedContent
from llama_stack.apis.common.content_types import InterleavedContent, URL
from llama_stack.apis.datasetio import DatasetIO, PaginatedRowsResult
from llama_stack.apis.eval import (
AppEvalTaskConfig,
@ -38,7 +38,7 @@ from llama_stack.apis.scoring import (
ScoringFnParams,
)
from llama_stack.apis.shields import Shield
from llama_stack.apis.tools import Tool, ToolGroupDef, ToolRuntime
from llama_stack.apis.tools import ToolDef, ToolRuntime
from llama_stack.providers.datatypes import RoutingTable
@ -127,7 +127,7 @@ class InferenceRouter(Inference):
response_format: Optional[ResponseFormat] = None,
tools: Optional[List[ToolDefinition]] = None,
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
tool_prompt_format: Optional[ToolPromptFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> AsyncGenerator:
@ -417,7 +417,9 @@ class ToolRuntimeRouter(ToolRuntime):
args=args,
)
async def discover_tools(self, tool_group: ToolGroupDef) -> List[Tool]:
return await self.routing_table.get_provider_impl(
tool_group.name
).discover_tools(tool_group)
async def list_runtime_tools(
self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
) -> List[ToolDef]:
return await self.routing_table.get_provider_impl(tool_group_id).list_tools(
tool_group_id, mcp_endpoint
)

View file

@ -6,7 +6,7 @@
from typing import Any, Dict, List, Optional
from pydantic import parse_obj_as
from pydantic import TypeAdapter
from llama_stack.apis.common.content_types import URL
from llama_stack.apis.common.type_system import ParamType
@ -26,20 +26,12 @@ from llama_stack.apis.scoring_functions import (
ScoringFunctions,
)
from llama_stack.apis.shields import Shield, Shields
from llama_stack.apis.tools import (
MCPToolGroupDef,
Tool,
ToolGroup,
ToolGroupDef,
ToolGroups,
UserDefinedToolGroupDef,
)
from llama_stack.apis.tools import Tool, ToolGroup, ToolGroups, ToolHost
from llama_stack.distribution.datatypes import (
RoutableObject,
RoutableObjectWithProvider,
RoutedProtocol,
)
from llama_stack.distribution.store import DistributionRegistry
from llama_stack.providers.datatypes import Api, RoutingTable
@ -361,7 +353,7 @@ class MemoryBanksRoutingTable(CommonRoutingTableImpl, MemoryBanks):
memory_bank_data["embedding_dimension"] = model.metadata[
"embedding_dimension"
]
memory_bank = parse_obj_as(MemoryBank, memory_bank_data)
memory_bank = TypeAdapter(MemoryBank).validate_python(memory_bank_data)
await self.register_object(memory_bank)
return memory_bank
@ -496,54 +488,44 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
async def list_tools(self, tool_group_id: Optional[str] = None) -> List[Tool]:
tools = await self.get_all_with_type("tool")
if tool_group_id:
tools = [tool for tool in tools if tool.tool_group == tool_group_id]
tools = [tool for tool in tools if tool.toolgroup_id == tool_group_id]
return tools
async def list_tool_groups(self) -> List[ToolGroup]:
return await self.get_all_with_type("tool_group")
async def get_tool_group(self, tool_group_id: str) -> ToolGroup:
return await self.get_object_by_identifier("tool_group", tool_group_id)
async def get_tool_group(self, toolgroup_id: str) -> ToolGroup:
return await self.get_object_by_identifier("tool_group", toolgroup_id)
async def get_tool(self, tool_name: str) -> Tool:
return await self.get_object_by_identifier("tool", tool_name)
async def register_tool_group(
self,
tool_group_id: str,
tool_group: ToolGroupDef,
provider_id: Optional[str] = None,
toolgroup_id: str,
provider_id: str,
mcp_endpoint: Optional[URL] = None,
args: Optional[Dict[str, Any]] = None,
) -> None:
tools = []
tool_defs = []
if provider_id is None:
if len(self.impls_by_provider_id.keys()) > 1:
raise ValueError(
f"No provider_id specified and multiple providers available. Please specify a provider_id. Available providers: {', '.join(self.impls_by_provider_id.keys())}"
)
provider_id = list(self.impls_by_provider_id.keys())[0]
if isinstance(tool_group, MCPToolGroupDef):
tool_defs = await self.impls_by_provider_id[provider_id].discover_tools(
tool_group
)
elif isinstance(tool_group, UserDefinedToolGroupDef):
tool_defs = tool_group.tools
else:
raise ValueError(f"Unknown tool group: {tool_group}")
tool_defs = await self.impls_by_provider_id[provider_id].list_runtime_tools(
toolgroup_id, mcp_endpoint
)
tool_host = (
ToolHost.model_context_protocol if mcp_endpoint else ToolHost.distribution
)
for tool_def in tool_defs:
tools.append(
Tool(
identifier=tool_def.name,
tool_group=tool_group_id,
description=tool_def.description,
parameters=tool_def.parameters,
toolgroup_id=toolgroup_id,
description=tool_def.description or "",
parameters=tool_def.parameters or [],
provider_id=provider_id,
tool_prompt_format=tool_def.tool_prompt_format,
provider_resource_id=tool_def.name,
metadata=tool_def.metadata,
tool_host=tool_host,
)
)
for tool in tools:
@ -561,9 +543,11 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
await self.dist_registry.register(
ToolGroup(
identifier=tool_group_id,
identifier=toolgroup_id,
provider_id=provider_id,
provider_resource_id=tool_group_id,
provider_resource_id=toolgroup_id,
mcp_endpoint=mcp_endpoint,
args=args,
)
)

View file

@ -16,6 +16,8 @@ import traceback
import warnings
from contextlib import asynccontextmanager
from importlib.metadata import version as parse_version
from pathlib import Path
from typing import Any, Union
@ -228,6 +230,52 @@ class TracingMiddleware:
await end_trace()
class ClientVersionMiddleware:
def __init__(self, app):
self.app = app
self.server_version = parse_version("llama-stack")
async def __call__(self, scope, receive, send):
if scope["type"] == "http":
headers = dict(scope.get("headers", []))
client_version = headers.get(b"x-llamastack-client-version", b"").decode()
if client_version:
try:
client_version_parts = tuple(
map(int, client_version.split(".")[:2])
)
server_version_parts = tuple(
map(int, self.server_version.split(".")[:2])
)
if client_version_parts != server_version_parts:
async def send_version_error(send):
await send(
{
"type": "http.response.start",
"status": 426,
"headers": [[b"content-type", b"application/json"]],
}
)
error_msg = json.dumps(
{
"error": {
"message": f"Client version {client_version} is not compatible with server version {self.server_version}. Please upgrade your client."
}
}
).encode()
await send(
{"type": "http.response.body", "body": error_msg}
)
return await send_version_error(send)
except (ValueError, IndexError):
# If version parsing fails, let the request through
pass
return await self.app(scope, receive, send)
def main():
"""Start the LlamaStack server."""
parser = argparse.ArgumentParser(description="Start the LlamaStack server.")
@ -242,7 +290,7 @@ def main():
parser.add_argument(
"--port",
type=int,
default=int(os.getenv("LLAMASTACK_PORT", 5000)),
default=int(os.getenv("LLAMA_STACK_PORT", 5000)),
help="Port to listen on",
)
parser.add_argument(
@ -291,6 +339,7 @@ def main():
app = FastAPI(lifespan=lifespan)
app.add_middleware(TracingMiddleware)
app.add_middleware(ClientVersionMiddleware)
try:
impls = asyncio.run(construct_stack(config))

View file

@ -4,15 +4,13 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import importlib.resources
import logging
import os
import re
from pathlib import Path
from typing import Any, Dict, Optional
import pkg_resources
import yaml
from termcolor import colored
from llama_stack.apis.agents import Agents
@ -33,14 +31,13 @@ from llama_stack.apis.scoring_functions import ScoringFunctions
from llama_stack.apis.shields import Shields
from llama_stack.apis.synthetic_data_generation import SyntheticDataGeneration
from llama_stack.apis.telemetry import Telemetry
from llama_stack.apis.tools import ToolGroups, ToolRuntime
from llama_stack.distribution.datatypes import StackRunConfig
from llama_stack.distribution.distribution import get_provider_registry
from llama_stack.distribution.resolver import ProviderRegistry, resolve_impls
from llama_stack.distribution.store.registry import create_dist_registry
from llama_stack.providers.datatypes import Api
log = logging.getLogger(__name__)
LLAMA_STACK_API_VERSION = "alpha"
@ -65,6 +62,8 @@ class LlamaStack(
Models,
Shields,
Inspect,
ToolGroups,
ToolRuntime,
):
pass
@ -81,6 +80,7 @@ RESOURCES = [
"list_scoring_functions",
),
("eval_tasks", Api.eval_tasks, "register_eval_task", "list_eval_tasks"),
("tool_groups", Api.tool_groups, "register_tool_group", "list_tool_groups"),
]
@ -210,14 +210,13 @@ async def construct_stack(
def get_stack_run_config_from_template(template: str) -> StackRunConfig:
template_path = pkg_resources.resource_filename(
"llama_stack", f"templates/{template}/run.yaml"
template_path = (
importlib.resources.files("llama_stack") / f"templates/{template}/run.yaml"
)
if not Path(template_path).exists():
raise ValueError(f"Template '{template}' not found at {template_path}")
with open(template_path) as f:
run_config = yaml.safe_load(f)
with importlib.resources.as_file(template_path) as path:
if not path.exists():
raise ValueError(f"Template '{template}' not found at {template_path}")
run_config = yaml.safe_load(path.open())
return StackRunConfig(**replace_env_vars(run_config))

View file

@ -90,6 +90,6 @@ $DOCKER_BINARY run $DOCKER_OPTS -it \
$env_vars \
-v "$yaml_config:/app/config.yaml" \
$mounts \
--env LLAMASTACK_PORT=$port \
--env LLAMA_STACK_PORT=$port \
--entrypoint='["python", "-m", "llama_stack.distribution.server.server", "--yaml-config", "/app/config.yaml"]' \
$docker_image:$version_tag

View file

@ -12,7 +12,6 @@ import pydantic
from llama_stack.distribution.datatypes import KVStoreConfig, RoutableObjectWithProvider
from llama_stack.distribution.utils.config_dirs import DISTRIBS_BASE_DIR
from llama_stack.providers.utils.kvstore import KVStore, kvstore_impl
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
@ -36,7 +35,7 @@ class DistributionRegistry(Protocol):
REGISTER_PREFIX = "distributions:registry"
KEY_VERSION = "v3"
KEY_VERSION = "v4"
KEY_FORMAT = f"{REGISTER_PREFIX}:{KEY_VERSION}::" + "{type}:{identifier}"

View file

@ -22,6 +22,8 @@ async def get_provider_impl(
deps[Api.memory],
deps[Api.safety],
deps[Api.memory_banks],
deps[Api.tool_runtime],
deps[Api.tool_groups],
)
await impl.initialize()
return impl

View file

@ -4,8 +4,8 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import asyncio
import copy
import json
import logging
import os
import re
@ -13,16 +13,16 @@ import secrets
import string
import uuid
from datetime import datetime
from typing import AsyncGenerator, Dict, List, Optional, Tuple
from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple
from urllib.parse import urlparse
import httpx
from llama_models.llama3.api.datatypes import BuiltinTool
from llama_models.llama3.api.datatypes import BuiltinTool, ToolCall, ToolParamDefinition
from llama_stack.apis.agents import (
AgentConfig,
AgentTool,
AgentToolGroup,
AgentToolGroupWithArgs,
AgentTurnCreateRequest,
AgentTurnResponseEvent,
AgentTurnResponseEventType,
@ -33,25 +33,14 @@ from llama_stack.apis.agents import (
AgentTurnResponseTurnCompletePayload,
AgentTurnResponseTurnStartPayload,
Attachment,
CodeInterpreterToolDefinition,
FunctionCallToolDefinition,
Document,
InferenceStep,
MemoryRetrievalStep,
MemoryToolDefinition,
PhotogenToolDefinition,
SearchToolDefinition,
ShieldCallStep,
StepType,
ToolExecutionStep,
Turn,
WolframAlphaToolDefinition,
)
from llama_stack.apis.common.content_types import (
InterleavedContent,
TextContentItem,
URL,
)
from llama_stack.apis.common.content_types import TextContentItem, URL
from llama_stack.apis.inference import (
ChatCompletionResponseEventType,
CompletionMessage,
@ -62,32 +51,20 @@ from llama_stack.apis.inference import (
SystemMessage,
ToolCallDelta,
ToolCallParseStatus,
ToolChoice,
ToolDefinition,
ToolResponse,
ToolResponseMessage,
UserMessage,
)
from llama_stack.apis.memory import Memory, MemoryBankDocument, QueryDocumentsResponse
from llama_stack.apis.memory import Memory, MemoryBankDocument
from llama_stack.apis.memory_banks import MemoryBanks, VectorMemoryBankParams
from llama_stack.apis.safety import Safety
from llama_stack.apis.tools import ToolGroups, ToolRuntime
from llama_stack.providers.utils.kvstore import KVStore
from llama_stack.providers.utils.memory.vector_store import concat_interleaved_content
from llama_stack.providers.utils.telemetry import tracing
from .persistence import AgentPersistence
from .rag.context_retriever import generate_rag_query
from .safety import SafetyException, ShieldRunnerMixin
from .tools.base import BaseTool
from .tools.builtin import (
CodeInterpreterTool,
interpret_content_as_attachment,
PhotogenTool,
SearchTool,
WolframAlphaTool,
)
from .tools.safety import SafeTool
log = logging.getLogger(__name__)
@ -98,6 +75,12 @@ def make_random_string(length: int = 8):
)
TOOLS_ATTACHMENT_KEY_REGEX = re.compile(r"__tools_attachment__=(\{.*?\})")
MEMORY_QUERY_TOOL = "query_memory"
WEB_SEARCH_TOOL = "web_search"
MEMORY_GROUP = "builtin::memory"
class ChatAgent(ShieldRunnerMixin):
def __init__(
self,
@ -108,6 +91,8 @@ class ChatAgent(ShieldRunnerMixin):
memory_api: Memory,
memory_banks_api: MemoryBanks,
safety_api: Safety,
tool_runtime_api: ToolRuntime,
tool_groups_api: ToolGroups,
persistence_store: KVStore,
):
self.agent_id = agent_id
@ -118,29 +103,8 @@ class ChatAgent(ShieldRunnerMixin):
self.memory_banks_api = memory_banks_api
self.safety_api = safety_api
self.storage = AgentPersistence(agent_id, persistence_store)
builtin_tools = []
for tool_defn in agent_config.tools:
if isinstance(tool_defn, WolframAlphaToolDefinition):
tool = WolframAlphaTool(tool_defn.api_key)
elif isinstance(tool_defn, SearchToolDefinition):
tool = SearchTool(tool_defn.engine, tool_defn.api_key)
elif isinstance(tool_defn, CodeInterpreterToolDefinition):
tool = CodeInterpreterTool()
elif isinstance(tool_defn, PhotogenToolDefinition):
tool = PhotogenTool(dump_dir=self.tempdir)
else:
continue
builtin_tools.append(
SafeTool(
tool,
safety_api,
tool_defn.input_shields,
tool_defn.output_shields,
)
)
self.tools_dict = {t.get_name(): t for t in builtin_tools}
self.tool_runtime_api = tool_runtime_api
self.tool_groups_api = tool_groups_api
ShieldRunnerMixin.__init__(
self,
@ -228,9 +192,10 @@ class ChatAgent(ShieldRunnerMixin):
session_id=request.session_id,
turn_id=turn_id,
input_messages=messages,
attachments=request.attachments or [],
sampling_params=self.agent_config.sampling_params,
stream=request.stream,
documents=request.documents,
toolgroups_for_turn=request.toolgroups,
):
if isinstance(chunk, CompletionMessage):
log.info(
@ -278,9 +243,10 @@ class ChatAgent(ShieldRunnerMixin):
session_id: str,
turn_id: str,
input_messages: List[Message],
attachments: List[Attachment],
sampling_params: SamplingParams,
stream: bool = False,
documents: Optional[List[Document]] = None,
toolgroups_for_turn: Optional[List[AgentToolGroup]] = None,
) -> AsyncGenerator:
# Doing async generators makes downstream code much simpler and everything amenable to
# streaming. However, it also makes things complicated here because AsyncGenerators cannot
@ -297,7 +263,13 @@ class ChatAgent(ShieldRunnerMixin):
yield res
async for res in self._run(
session_id, turn_id, input_messages, attachments, sampling_params, stream
session_id,
turn_id,
input_messages,
sampling_params,
stream,
documents,
toolgroups_for_turn,
):
if isinstance(res, bool):
return
@ -353,6 +325,7 @@ class ChatAgent(ShieldRunnerMixin):
event=AgentTurnResponseEvent(
payload=AgentTurnResponseStepCompletePayload(
step_type=StepType.shield_call.value,
step_id=step_id,
step_details=ShieldCallStep(
step_id=step_id,
turn_id=turn_id,
@ -373,6 +346,7 @@ class ChatAgent(ShieldRunnerMixin):
event=AgentTurnResponseEvent(
payload=AgentTurnResponseStepCompletePayload(
step_type=StepType.shield_call.value,
step_id=step_id,
step_details=ShieldCallStep(
step_id=step_id,
turn_id=turn_id,
@ -388,73 +362,116 @@ class ChatAgent(ShieldRunnerMixin):
session_id: str,
turn_id: str,
input_messages: List[Message],
attachments: List[Attachment],
sampling_params: SamplingParams,
stream: bool = False,
documents: Optional[List[Document]] = None,
toolgroups_for_turn: Optional[List[AgentToolGroup]] = None,
) -> AsyncGenerator:
enabled_tools = set(t.type for t in self.agent_config.tools)
need_rag_context = await self._should_retrieve_context(
input_messages, attachments
)
if need_rag_context:
step_id = str(uuid.uuid4())
yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseStepStartPayload(
step_type=StepType.memory_retrieval.value,
step_id=step_id,
toolgroup_args = {}
for toolgroup in self.agent_config.toolgroups:
if isinstance(toolgroup, AgentToolGroupWithArgs):
toolgroup_args[toolgroup.name] = toolgroup.args
if toolgroups_for_turn:
for toolgroup in toolgroups_for_turn:
if isinstance(toolgroup, AgentToolGroupWithArgs):
toolgroup_args[toolgroup.name] = toolgroup.args
tool_defs, tool_to_group = await self._get_tool_defs(toolgroups_for_turn)
if documents:
await self.handle_documents(
session_id, documents, input_messages, tool_defs
)
if MEMORY_QUERY_TOOL in tool_defs and len(input_messages) > 0:
memory_tool_group = tool_to_group.get(MEMORY_QUERY_TOOL, None)
if memory_tool_group is None:
raise ValueError(f"Memory tool group not found for {MEMORY_QUERY_TOOL}")
with tracing.span(MEMORY_QUERY_TOOL) as span:
step_id = str(uuid.uuid4())
yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseStepStartPayload(
step_type=StepType.tool_execution.value,
step_id=step_id,
)
)
)
)
query_args = {
"messages": [msg.content for msg in input_messages],
**toolgroup_args.get(memory_tool_group, {}),
}
# TODO: find older context from the session and either replace it
# or append with a sliding window. this is really a very simplistic implementation
with tracing.span("retrieve_rag_context") as span:
rag_context, bank_ids = await self._retrieve_context(
session_id, input_messages, attachments
session_info = await self.storage.get_session_info(session_id)
# if the session has a memory bank id, let the memory tool use it
if session_info.memory_bank_id:
if "memory_bank_ids" not in query_args:
query_args["memory_bank_ids"] = []
query_args["memory_bank_ids"].append(session_info.memory_bank_id)
yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseStepProgressPayload(
step_type=StepType.tool_execution.value,
step_id=step_id,
tool_call_delta=ToolCallDelta(
parse_status=ToolCallParseStatus.success,
content=ToolCall(
call_id="",
tool_name=MEMORY_QUERY_TOOL,
arguments={},
),
),
)
)
)
result = await self.tool_runtime_api.invoke_tool(
tool_name=MEMORY_QUERY_TOOL,
args=query_args,
)
yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseStepCompletePayload(
step_type=StepType.tool_execution.value,
step_id=step_id,
step_details=ToolExecutionStep(
step_id=step_id,
turn_id=turn_id,
tool_calls=[
ToolCall(
call_id="",
tool_name=MEMORY_QUERY_TOOL,
arguments={},
)
],
tool_responses=[
ToolResponse(
call_id="",
tool_name=MEMORY_QUERY_TOOL,
content=result.content,
)
],
),
)
)
)
span.set_attribute(
"input", [m.model_dump_json() for m in input_messages]
)
span.set_attribute("output", rag_context)
span.set_attribute("bank_ids", bank_ids)
step_id = str(uuid.uuid4())
yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseStepCompletePayload(
step_type=StepType.memory_retrieval.value,
step_id=step_id,
step_details=MemoryRetrievalStep(
turn_id=turn_id,
step_id=step_id,
memory_bank_ids=bank_ids,
inserted_context=rag_context or "",
),
)
)
)
if rag_context:
last_message = input_messages[-1]
last_message.context = rag_context
elif attachments and AgentTool.code_interpreter.value in enabled_tools:
urls = [a.content for a in attachments if isinstance(a.content, URL)]
# TODO: we need to migrate URL away from str type
pattern = re.compile("^(https?://|file://|data:)")
urls += [
URL(uri=a.content) for a in attachments if pattern.match(a.content)
]
msg = await attachment_message(self.tempdir, urls)
input_messages.append(msg)
span.set_attribute("output", result.content)
span.set_attribute("error_code", result.error_code)
span.set_attribute("error_message", result.error_message)
span.set_attribute("tool_name", MEMORY_QUERY_TOOL)
if result.error_code == 0:
last_message = input_messages[-1]
last_message.context = result.content
output_attachments = []
n_iter = 0
# Build a map of custom tools to their definitions for faster lookup
client_tools = {}
for tool in self.agent_config.client_tools:
client_tools[tool.name] = tool
while True:
msg = input_messages[-1]
step_id = str(uuid.uuid4())
yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
@ -473,7 +490,11 @@ class ChatAgent(ShieldRunnerMixin):
async for chunk in await self.inference_api.chat_completion(
self.agent_config.model,
input_messages,
tools=self._get_tools(),
tools=[
tool
for tool in tool_defs.values()
if tool_to_group.get(tool.tool_name, None) != MEMORY_GROUP
],
tool_prompt_format=self.agent_config.tool_prompt_format,
stream=True,
sampling_params=sampling_params,
@ -572,9 +593,9 @@ class ChatAgent(ShieldRunnerMixin):
# TODO: UPDATE RETURN TYPE TO SEND A TUPLE OF (MESSAGE, ATTACHMENTS)
if len(output_attachments) > 0:
if isinstance(message.content, list):
message.content += attachments
message.content += output_attachments
else:
message.content = [message.content] + attachments
message.content = [message.content] + output_attachments
yield message
else:
log.info(f"Partial message: {str(message)}")
@ -582,9 +603,7 @@ class ChatAgent(ShieldRunnerMixin):
else:
log.info(f"{str(message)}")
tool_call = message.tool_calls[0]
name = tool_call.tool_name
if not isinstance(name, BuiltinTool) or name not in enabled_tools:
if tool_call.tool_name in client_tools:
yield message
return
@ -607,16 +626,22 @@ class ChatAgent(ShieldRunnerMixin):
)
)
tool_name = tool_call.tool_name
if isinstance(tool_name, BuiltinTool):
tool_name = tool_name.value
with tracing.span(
"tool_execution",
{
"tool_name": tool_call.tool_name,
"tool_name": tool_name,
"input": message.model_dump_json(),
},
) as span:
result_messages = await execute_tool_call_maybe(
self.tools_dict,
self.tool_runtime_api,
session_id,
[message],
toolgroup_args,
tool_to_group,
)
assert (
len(result_messages) == 1
@ -628,6 +653,7 @@ class ChatAgent(ShieldRunnerMixin):
event=AgentTurnResponseEvent(
payload=AgentTurnResponseStepCompletePayload(
step_type=StepType.tool_execution.value,
step_id=step_id,
step_details=ToolExecutionStep(
step_id=step_id,
turn_id=turn_id,
@ -647,7 +673,7 @@ class ChatAgent(ShieldRunnerMixin):
# TODO: add tool-input touchpoint and a "start" event for this step also
# but that needs a lot more refactoring of Tool code potentially
if out_attachment := interpret_content_as_attachment(
if out_attachment := _interpret_content_as_attachment(
result_message.content
):
# NOTE: when we push this message back to the model, the model may ignore the
@ -659,6 +685,150 @@ class ChatAgent(ShieldRunnerMixin):
n_iter += 1
async def _get_tool_defs(
self, toolgroups_for_turn: Optional[List[AgentToolGroup]] = None
) -> Tuple[Dict[str, ToolDefinition], Dict[str, str]]:
# Determine which tools to include
agent_config_toolgroups = set(
(
toolgroup.name
if isinstance(toolgroup, AgentToolGroupWithArgs)
else toolgroup
)
for toolgroup in self.agent_config.toolgroups
)
toolgroups_for_turn_set = (
agent_config_toolgroups
if toolgroups_for_turn is None
else {
(
toolgroup.name
if isinstance(toolgroup, AgentToolGroupWithArgs)
else toolgroup
)
for toolgroup in toolgroups_for_turn
}
)
tool_def_map = {}
tool_to_group = {}
for tool_def in self.agent_config.client_tools:
if tool_def_map.get(tool_def.name, None):
raise ValueError(f"Tool {tool_def.name} already exists")
tool_def_map[tool_def.name] = ToolDefinition(
tool_name=tool_def.name,
description=tool_def.description,
parameters={
param.name: ToolParamDefinition(
param_type=param.parameter_type,
description=param.description,
required=param.required,
default=param.default,
)
for param in tool_def.parameters
},
)
tool_to_group[tool_def.name] = "__client_tools__"
for toolgroup_name in agent_config_toolgroups:
if toolgroup_name not in toolgroups_for_turn_set:
continue
tools = await self.tool_groups_api.list_tools(tool_group_id=toolgroup_name)
for tool_def in tools:
if (
toolgroup_name.startswith("builtin")
and toolgroup_name != MEMORY_GROUP
):
tool_name = tool_def.identifier
built_in_type = BuiltinTool.brave_search
if tool_name == "web_search":
built_in_type = BuiltinTool.brave_search
else:
built_in_type = BuiltinTool(tool_name)
if tool_def_map.get(built_in_type, None):
raise ValueError(f"Tool {built_in_type} already exists")
tool_def_map[built_in_type] = ToolDefinition(
tool_name=built_in_type
)
tool_to_group[built_in_type] = tool_def.toolgroup_id
continue
if tool_def_map.get(tool_def.identifier, None):
raise ValueError(f"Tool {tool_def.identifier} already exists")
tool_def_map[tool_def.identifier] = ToolDefinition(
tool_name=tool_def.identifier,
description=tool_def.description,
parameters={
param.name: ToolParamDefinition(
param_type=param.parameter_type,
description=param.description,
required=param.required,
default=param.default,
)
for param in tool_def.parameters
},
)
tool_to_group[tool_def.identifier] = tool_def.toolgroup_id
return tool_def_map, tool_to_group
async def handle_documents(
self,
session_id: str,
documents: List[Document],
input_messages: List[Message],
tool_defs: Dict[str, ToolDefinition],
) -> None:
memory_tool = tool_defs.get(MEMORY_QUERY_TOOL, None)
code_interpreter_tool = tool_defs.get(BuiltinTool.code_interpreter, None)
content_items = []
url_items = []
pattern = re.compile("^(https?://|file://|data:)")
for d in documents:
if isinstance(d.content, URL):
url_items.append(d.content)
elif pattern.match(d.content):
url_items.append(URL(uri=d.content))
else:
content_items.append(d)
# Save the contents to a tempdir and use its path as a URL if code interpreter is present
if code_interpreter_tool:
for c in content_items:
temp_file_path = os.path.join(
self.tempdir, f"{make_random_string()}.txt"
)
with open(temp_file_path, "w") as temp_file:
temp_file.write(c.content)
url_items.append(URL(uri=f"file://{temp_file_path}"))
if memory_tool and code_interpreter_tool:
# if both memory and code_interpreter are available, we download the URLs
# and attach the data to the last message.
msg = await attachment_message(self.tempdir, url_items)
input_messages.append(msg)
# Since memory is present, add all the data to the memory bank
await self.add_to_session_memory_bank(session_id, documents)
elif code_interpreter_tool:
# if only code_interpreter is available, we download the URLs to a tempdir
# and attach the path to them as a message to inference with the
# assumption that the model invokes the code_interpreter tool with the path
msg = await attachment_message(self.tempdir, url_items)
input_messages.append(msg)
elif memory_tool:
# if only memory is available, we load the data from the URLs and content items to the memory bank
await self.add_to_session_memory_bank(session_id, documents)
else:
# if no memory or code_interpreter tool is available,
# we try to load the data from the URLs and content items as a message to inference
# and add it to the last message's context
input_messages[-1].context = "\n".join(
[doc.content for doc in content_items]
+ await load_data_from_urls(url_items)
)
async def _ensure_memory_bank(self, session_id: str) -> str:
session_info = await self.storage.get_session_info(session_id)
if session_info is None:
@ -679,129 +849,39 @@ class ChatAgent(ShieldRunnerMixin):
return bank_id
async def _should_retrieve_context(
self, messages: List[Message], attachments: List[Attachment]
) -> bool:
enabled_tools = set(t.type for t in self.agent_config.tools)
if attachments:
if (
AgentTool.code_interpreter.value in enabled_tools
and self.agent_config.tool_choice == ToolChoice.required
):
return False
else:
return True
return AgentTool.memory.value in enabled_tools
def _memory_tool_definition(self) -> Optional[MemoryToolDefinition]:
for t in self.agent_config.tools:
if t.type == AgentTool.memory.value:
return t
return None
async def _retrieve_context(
self, session_id: str, messages: List[Message], attachments: List[Attachment]
) -> Tuple[Optional[InterleavedContent], List[int]]: # (rag_context, bank_ids)
bank_ids = []
memory = self._memory_tool_definition()
assert memory is not None, "Memory tool not configured"
bank_ids.extend(c.bank_id for c in memory.memory_bank_configs)
if attachments:
bank_id = await self._ensure_memory_bank(session_id)
bank_ids.append(bank_id)
documents = [
MemoryBankDocument(
document_id=str(uuid.uuid4()),
content=a.content,
mime_type=a.mime_type,
metadata={},
)
for a in attachments
]
with tracing.span("insert_documents"):
await self.memory_api.insert_documents(bank_id, documents)
else:
session_info = await self.storage.get_session_info(session_id)
if session_info.memory_bank_id:
bank_ids.append(session_info.memory_bank_id)
if not bank_ids:
# this can happen if the per-session memory bank is not yet populated
# (i.e., no prior turns uploaded an Attachment)
return None, []
query = await generate_rag_query(
memory.query_generator_config, messages, inference_api=self.inference_api
)
tasks = [
self.memory_api.query_documents(
bank_id=bank_id,
query=query,
params={
"max_chunks": 5,
},
async def add_to_session_memory_bank(
self, session_id: str, data: List[Document]
) -> None:
bank_id = await self._ensure_memory_bank(session_id)
documents = [
MemoryBankDocument(
document_id=str(uuid.uuid4()),
content=a.content,
mime_type=a.mime_type,
metadata={},
)
for bank_id in bank_ids
for a in data
]
results: List[QueryDocumentsResponse] = await asyncio.gather(*tasks)
chunks = [c for r in results for c in r.chunks]
scores = [s for r in results for s in r.scores]
if not chunks:
return None, bank_ids
# sort by score
chunks, scores = zip(
*sorted(zip(chunks, scores), key=lambda x: x[1], reverse=True)
await self.memory_api.insert_documents(
bank_id=bank_id,
documents=documents,
)
tokens = 0
picked = []
for c in chunks[: memory.max_chunks]:
tokens += c.token_count
if tokens > memory.max_tokens_in_context:
log.error(
f"Using {len(picked)} chunks; reached max tokens in context: {tokens}",
)
break
picked.append(f"id:{c.document_id}; content:{c.content}")
return (
concat_interleaved_content(
[
"Here are the retrieved documents for relevant context:\n=== START-RETRIEVED-CONTEXT ===\n",
*picked,
"\n=== END-RETRIEVED-CONTEXT ===\n",
]
),
bank_ids,
)
def _get_tools(self) -> List[ToolDefinition]:
ret = []
for t in self.agent_config.tools:
if isinstance(t, SearchToolDefinition):
ret.append(ToolDefinition(tool_name=BuiltinTool.brave_search))
elif isinstance(t, WolframAlphaToolDefinition):
ret.append(ToolDefinition(tool_name=BuiltinTool.wolfram_alpha))
elif isinstance(t, PhotogenToolDefinition):
ret.append(ToolDefinition(tool_name=BuiltinTool.photogen))
elif isinstance(t, CodeInterpreterToolDefinition):
ret.append(ToolDefinition(tool_name=BuiltinTool.code_interpreter))
elif isinstance(t, FunctionCallToolDefinition):
ret.append(
ToolDefinition(
tool_name=t.function_name,
description=t.description,
parameters=t.parameters,
)
)
return ret
async def load_data_from_urls(urls: List[URL]) -> List[str]:
data = []
for url in urls:
uri = url.uri
if uri.startswith("file://"):
filepath = uri[len("file://") :]
with open(filepath, "r") as f:
data.append(f.read())
elif uri.startswith("http"):
async with httpx.AsyncClient() as client:
r = await client.get(uri)
resp = r.text
data.append(resp)
return data
async def attachment_message(tempdir: str, urls: List[URL]) -> ToolResponseMessage:
@ -839,7 +919,11 @@ async def attachment_message(tempdir: str, urls: List[URL]) -> ToolResponseMessa
async def execute_tool_call_maybe(
tools_dict: Dict[str, BaseTool], messages: List[CompletionMessage]
tool_runtime_api: ToolRuntime,
session_id: str,
messages: List[CompletionMessage],
toolgroup_args: Dict[str, Dict[str, Any]],
tool_to_group: Dict[str, str],
) -> List[ToolResponseMessage]:
# While Tools.run interface takes a list of messages,
# All tools currently only run on a single message
@ -851,11 +935,45 @@ async def execute_tool_call_maybe(
tool_call = message.tool_calls[0]
name = tool_call.tool_name
assert isinstance(name, BuiltinTool)
group_name = tool_to_group.get(name, None)
if group_name is None:
raise ValueError(f"Tool {name} not found in any tool group")
# get the arguments generated by the model and augment with toolgroup arg overrides for the agent
tool_call_args = tool_call.arguments
tool_call_args.update(toolgroup_args.get(group_name, {}))
if isinstance(name, BuiltinTool):
if name == BuiltinTool.brave_search:
name = WEB_SEARCH_TOOL
else:
name = name.value
name = name.value
result = await tool_runtime_api.invoke_tool(
tool_name=name,
args=dict(
session_id=session_id,
**tool_call_args,
),
)
assert name in tools_dict, f"Tool {name} not found"
tool = tools_dict[name]
result_messages = await tool.run(messages)
return result_messages
return [
ToolResponseMessage(
call_id=tool_call.call_id,
tool_name=tool_call.tool_name,
content=result.content,
)
]
def _interpret_content_as_attachment(
content: str,
) -> Optional[Attachment]:
match = re.search(TOOLS_ATTACHMENT_KEY_REGEX, content)
if match:
snippet = match.group(1)
data = json.loads(snippet)
return Attachment(
url=URL(uri="file://" + data["filepath"]),
mime_type=data["mimetype"],
)
return None

View file

@ -19,17 +19,17 @@ from llama_stack.apis.agents import (
Agents,
AgentSessionCreateResponse,
AgentStepResponse,
AgentToolGroup,
AgentTurnCreateRequest,
Attachment,
Document,
Session,
Turn,
)
from llama_stack.apis.inference import Inference, ToolResponseMessage, UserMessage
from llama_stack.apis.memory import Memory
from llama_stack.apis.memory_banks import MemoryBanks
from llama_stack.apis.safety import Safety
from llama_stack.apis.tools import ToolGroups, ToolRuntime
from llama_stack.providers.utils.kvstore import InmemoryKVStoreImpl, kvstore_impl
from .agent_instance import ChatAgent
@ -47,12 +47,16 @@ class MetaReferenceAgentsImpl(Agents):
memory_api: Memory,
safety_api: Safety,
memory_banks_api: MemoryBanks,
tool_runtime_api: ToolRuntime,
tool_groups_api: ToolGroups,
):
self.config = config
self.inference_api = inference_api
self.memory_api = memory_api
self.safety_api = safety_api
self.memory_banks_api = memory_banks_api
self.tool_runtime_api = tool_runtime_api
self.tool_groups_api = tool_groups_api
self.in_memory_store = InmemoryKVStoreImpl()
self.tempdir = tempfile.mkdtemp()
@ -112,6 +116,8 @@ class MetaReferenceAgentsImpl(Agents):
safety_api=self.safety_api,
memory_api=self.memory_api,
memory_banks_api=self.memory_banks_api,
tool_runtime_api=self.tool_runtime_api,
tool_groups_api=self.tool_groups_api,
persistence_store=(
self.persistence_store
if agent_config.enable_session_persistence
@ -141,15 +147,17 @@ class MetaReferenceAgentsImpl(Agents):
ToolResponseMessage,
]
],
attachments: Optional[List[Attachment]] = None,
toolgroups: Optional[List[AgentToolGroup]] = None,
documents: Optional[List[Document]] = None,
stream: Optional[bool] = False,
) -> AsyncGenerator:
request = AgentTurnCreateRequest(
agent_id=agent_id,
session_id=session_id,
messages=messages,
attachments=attachments,
stream=True,
toolgroups=toolgroups,
documents=documents,
)
if stream:
return self._create_agent_turn_streaming(request)

View file

@ -8,13 +8,11 @@ import json
import logging
import uuid
from datetime import datetime
from typing import List, Optional
from pydantic import BaseModel
from llama_stack.apis.agents import Turn
from llama_stack.providers.utils.kvstore import KVStore
log = logging.getLogger(__name__)

View file

@ -1,93 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import unittest
from llama_models.llama3.api.datatypes import (
Attachment,
BuiltinTool,
CompletionMessage,
StopReason,
ToolCall,
)
from ..tools.builtin import CodeInterpreterTool
class TestCodeInterpreter(unittest.IsolatedAsyncioTestCase):
async def test_matplotlib(self):
tool = CodeInterpreterTool()
code = """
import matplotlib.pyplot as plt
import numpy as np
x = np.array([1, 1])
y = np.array([0, 10])
plt.plot(x, y)
plt.title('x = 1')
plt.xlabel('x')
plt.ylabel('y')
plt.grid(True)
plt.axvline(x=1, color='r')
plt.show()
"""
message = CompletionMessage(
role="assistant",
content="",
tool_calls=[
ToolCall(
call_id="call_id",
tool_name=BuiltinTool.code_interpreter,
arguments={"code": code},
)
],
stop_reason=StopReason.end_of_message,
)
ret = await tool.run([message])
self.assertEqual(len(ret), 1)
output = ret[0].content
self.assertIsInstance(output, Attachment)
self.assertEqual(output.mime_type, "image/png")
async def test_path_unlink(self):
tool = CodeInterpreterTool()
code = """
import os
from pathlib import Path
import tempfile
dpath = Path(os.environ["MPLCONFIGDIR"])
with open(dpath / "test", "w") as f:
f.write("hello")
Path(dpath / "test").unlink()
print("_OK_")
"""
message = CompletionMessage(
role="assistant",
content="",
tool_calls=[
ToolCall(
call_id="call_id",
tool_name=BuiltinTool.code_interpreter,
arguments={"code": code},
)
],
stop_reason=StopReason.end_of_message,
)
ret = await tool.run([message])
self.assertEqual(len(ret), 1)
output = ret[0].content
self.assertTrue("_OK_" in output)
if __name__ == "__main__":
unittest.main()

View file

@ -4,21 +4,26 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import tempfile
from typing import AsyncIterator, List, Optional, Union
import pytest
from llama_models.llama3.api.datatypes import BuiltinTool
from llama_stack.apis.agents import (
AgentConfig,
AgentToolGroupWithArgs,
AgentTurnCreateRequest,
AgentTurnResponseTurnCompletePayload,
StepType,
)
from llama_stack.apis.common.content_types import URL
from llama_stack.apis.inference import (
ChatCompletionResponse,
ChatCompletionResponseEvent,
ChatCompletionResponseStreamChunk,
CompletionMessage,
LogProbConfig,
Message,
ResponseFormat,
SamplingParams,
@ -27,13 +32,24 @@ from llama_stack.apis.inference import (
UserMessage,
)
from llama_stack.apis.memory import MemoryBank
from llama_stack.apis.memory_banks import BankParams, VectorMemoryBank
from llama_stack.apis.safety import RunShieldResponse
from ..agents import (
AGENT_INSTANCES_BY_ID,
MetaReferenceAgentsImpl,
MetaReferenceInferenceConfig,
from llama_stack.apis.tools import (
Tool,
ToolDef,
ToolGroup,
ToolHost,
ToolInvocationResult,
ToolPromptFormat,
)
from llama_stack.providers.inline.agents.meta_reference.agent_instance import (
MEMORY_QUERY_TOOL,
)
from llama_stack.providers.inline.agents.meta_reference.agents import (
MetaReferenceAgentsImpl,
MetaReferenceAgentsImplConfig,
)
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
class MockInferenceAPI:
@ -48,10 +64,10 @@ class MockInferenceAPI:
tool_prompt_format: Optional[ToolPromptFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> AsyncIterator[
Union[ChatCompletionResponseStreamChunk, ChatCompletionResponse]
) -> Union[
ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]
]:
if stream:
async def stream_response():
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type="start",
@ -65,19 +81,7 @@ class MockInferenceAPI:
delta="AI is a fascinating field...",
)
)
# yield ChatCompletionResponseStreamChunk(
# event=ChatCompletionResponseEvent(
# event_type="progress",
# delta=ToolCallDelta(
# content=ToolCall(
# call_id="123",
# tool_name=BuiltinTool.brave_search.value,
# arguments={"query": "AI history"},
# ),
# parse_status="success",
# ),
# )
# )
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type="complete",
@ -85,12 +89,17 @@ class MockInferenceAPI:
stop_reason="end_of_turn",
)
)
if stream:
return stream_response()
else:
yield ChatCompletionResponse(
return ChatCompletionResponse(
completion_message=CompletionMessage(
role="assistant", content="Mock response", stop_reason="end_of_turn"
role="assistant",
content="Mock response",
stop_reason="end_of_turn",
),
logprobs=[0.1, 0.2, 0.3] if logprobs else None,
logprobs={"token_logprobs": [0.1, 0.2, 0.3]} if logprobs else None,
)
@ -165,6 +174,98 @@ class MockMemoryAPI:
self.documents[bank_id].pop(doc_id, None)
class MockToolGroupsAPI:
async def register_tool_group(
self, toolgroup_id: str, provider_id: str, mcp_endpoint=None, args=None
) -> None:
pass
async def get_tool_group(self, toolgroup_id: str) -> ToolGroup:
return ToolGroup(
identifier=toolgroup_id,
provider_resource_id=toolgroup_id,
)
async def list_tool_groups(self) -> List[ToolGroup]:
return []
async def list_tools(self, tool_group_id: Optional[str] = None) -> List[Tool]:
if tool_group_id == MEMORY_TOOLGROUP:
return [
Tool(
identifier=MEMORY_QUERY_TOOL,
provider_resource_id=MEMORY_QUERY_TOOL,
toolgroup_id=MEMORY_TOOLGROUP,
tool_host=ToolHost.client,
description="Mock tool",
provider_id="builtin::memory",
parameters=[],
)
]
if tool_group_id == CODE_INTERPRETER_TOOLGROUP:
return [
Tool(
identifier="code_interpreter",
provider_resource_id="code_interpreter",
toolgroup_id=CODE_INTERPRETER_TOOLGROUP,
tool_host=ToolHost.client,
description="Mock tool",
provider_id="builtin::code_interpreter",
parameters=[],
)
]
return []
async def get_tool(self, tool_name: str) -> Tool:
return Tool(
identifier=tool_name,
provider_resource_id=tool_name,
toolgroup_id="mock_group",
tool_host=ToolHost.client,
description="Mock tool",
provider_id="mock_provider",
parameters=[],
)
async def unregister_tool_group(self, tool_group_id: str) -> None:
pass
class MockToolRuntimeAPI:
async def list_runtime_tools(
self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
) -> List[ToolDef]:
return []
async def invoke_tool(self, tool_name: str, args: dict) -> ToolInvocationResult:
return ToolInvocationResult(content={"result": "Mock tool result"})
class MockMemoryBanksAPI:
async def list_memory_banks(self) -> List[MemoryBank]:
return []
async def get_memory_bank(self, memory_bank_id: str) -> Optional[MemoryBank]:
return None
async def register_memory_bank(
self,
memory_bank_id: str,
params: BankParams,
provider_id: Optional[str] = None,
provider_memory_bank_id: Optional[str] = None,
) -> MemoryBank:
return VectorMemoryBank(
identifier=memory_bank_id,
provider_resource_id=provider_memory_bank_id or memory_bank_id,
embedding_model="mock_model",
chunk_size_in_tokens=512,
)
async def unregister_memory_bank(self, memory_bank_id: str) -> None:
pass
@pytest.fixture
def mock_inference_api():
return MockInferenceAPI()
@ -181,64 +282,107 @@ def mock_memory_api():
@pytest.fixture
async def chat_agent(mock_inference_api, mock_safety_api, mock_memory_api):
def mock_tool_groups_api():
return MockToolGroupsAPI()
@pytest.fixture
def mock_tool_runtime_api():
return MockToolRuntimeAPI()
@pytest.fixture
def mock_memory_banks_api():
return MockMemoryBanksAPI()
@pytest.fixture
async def get_agents_impl(
mock_inference_api,
mock_safety_api,
mock_memory_api,
mock_memory_banks_api,
mock_tool_runtime_api,
mock_tool_groups_api,
):
sqlite_file = tempfile.NamedTemporaryFile(delete=False, suffix=".db")
impl = MetaReferenceAgentsImpl(
config=MetaReferenceInferenceConfig(),
config=MetaReferenceAgentsImplConfig(
persistence_store=SqliteKVStoreConfig(
db_name=sqlite_file.name,
),
),
inference_api=mock_inference_api,
safety_api=mock_safety_api,
memory_api=mock_memory_api,
memory_banks_api=mock_memory_banks_api,
tool_runtime_api=mock_tool_runtime_api,
tool_groups_api=mock_tool_groups_api,
)
await impl.initialize()
return impl
@pytest.fixture
async def get_chat_agent(get_agents_impl):
impl = await get_agents_impl
agent_config = AgentConfig(
model="test_model",
instructions="You are a helpful assistant.",
sampling_params=SamplingParams(),
tools=[
# SearchToolDefinition(
# name="brave_search",
# api_key="test_key",
# ),
],
toolgroups=[],
tool_choice=ToolChoice.auto,
enable_session_persistence=False,
input_shields=[],
output_shields=[],
input_shields=["test_shield"],
)
response = await impl.create_agent(agent_config)
agent = AGENT_INSTANCES_BY_ID[response.agent_id]
return agent
return await impl.get_agent(response.agent_id)
MEMORY_TOOLGROUP = "builtin::memory"
CODE_INTERPRETER_TOOLGROUP = "builtin::code_interpreter"
@pytest.fixture
async def get_chat_agent_with_tools(get_agents_impl, request):
impl = await get_agents_impl
toolgroups = request.param
agent_config = AgentConfig(
model="test_model",
instructions="You are a helpful assistant.",
toolgroups=toolgroups,
tool_choice=ToolChoice.auto,
enable_session_persistence=False,
input_shields=["test_shield"],
)
response = await impl.create_agent(agent_config)
return await impl.get_agent(response.agent_id)
@pytest.mark.asyncio
async def test_chat_agent_create_session(chat_agent):
session = chat_agent.create_session("Test Session")
assert session.session_name == "Test Session"
assert session.turns == []
assert session.session_id in chat_agent.sessions
@pytest.mark.asyncio
async def test_chat_agent_create_and_execute_turn(chat_agent):
session = chat_agent.create_session("Test Session")
async def test_chat_agent_create_and_execute_turn(get_chat_agent):
chat_agent = await get_chat_agent
session_id = await chat_agent.create_session("Test Session")
request = AgentTurnCreateRequest(
agent_id="random",
session_id=session.session_id,
agent_id=chat_agent.agent_id,
session_id=session_id,
messages=[UserMessage(content="Hello")],
stream=True,
)
responses = []
async for response in chat_agent.create_and_execute_turn(request):
responses.append(response)
print(responses)
assert len(responses) > 0
assert len(responses) == 4 # TurnStart, StepStart, StepComplete, TurnComplete
assert (
len(responses) == 7
) # TurnStart, ShieldCallStart, ShieldCallComplete, StepStart, StepProgress, StepComplete, TurnComplete
assert responses[0].event.payload.turn_id is not None
@pytest.mark.asyncio
async def test_run_multiple_shields_wrapper(chat_agent):
async def test_run_multiple_shields_wrapper(get_chat_agent):
chat_agent = await get_chat_agent
messages = [UserMessage(content="Test message")]
shields = ["test_shield"]
@ -254,69 +398,95 @@ async def test_run_multiple_shields_wrapper(chat_agent):
assert len(responses) == 2 # StepStart, StepComplete
assert responses[0].event.payload.step_type.value == "shield_call"
assert not responses[1].event.payload.step_details.response.is_violation
assert not responses[1].event.payload.step_details.violation
@pytest.mark.asyncio
@pytest.mark.skip(reason="Not yet implemented; need to mock out tool execution easily")
async def test_chat_agent_complex_turn(chat_agent):
# Setup
session = chat_agent.create_session("Test Session")
async def test_chat_agent_complex_turn(get_chat_agent):
chat_agent = await get_chat_agent
session_id = await chat_agent.create_session("Test Session")
request = AgentTurnCreateRequest(
agent_id="random",
session_id=session.session_id,
agent_id=chat_agent.agent_id,
session_id=session_id,
messages=[UserMessage(content="Tell me about AI and then use a tool.")],
stream=True,
)
# Execute the turn
responses = []
async for response in chat_agent.create_and_execute_turn(request):
responses.append(response)
# Assertions
assert len(responses) > 0
# Check for the presence of different step types
step_types = [
response.event.payload.step_type
for response in responses
if hasattr(response.event.payload, "step_type")
]
assert "shield_call" in step_types, "Shield call step is missing"
assert "inference" in step_types, "Inference step is missing"
assert "tool_execution" in step_types, "Tool execution step is missing"
assert StepType.shield_call in step_types, "Shield call step is missing"
assert StepType.inference in step_types, "Inference step is missing"
# Check for the presence of start and complete events
event_types = [
response.event.payload.event_type
for response in responses
if hasattr(response.event.payload, "event_type")
]
assert "start" in event_types, "Start event is missing"
assert "complete" in event_types, "Complete event is missing"
assert "turn_start" in event_types, "Start event is missing"
assert "turn_complete" in event_types, "Complete event is missing"
# Check for the presence of tool call
tool_calls = [
response.event.payload.tool_call
for response in responses
if hasattr(response.event.payload, "tool_call")
]
assert any(
tool_call
for tool_call in tool_calls
if tool_call and tool_call.content.get("name") == "memory"
), "Memory tool call is missing"
# Check for the final turn complete event
assert any(
isinstance(response.event.payload, AgentTurnResponseTurnCompletePayload)
for response in responses
), "Turn complete event is missing"
turn_complete_payload = next(
response.event.payload
for response in responses
if isinstance(response.event.payload, AgentTurnResponseTurnCompletePayload)
)
turn = turn_complete_payload.turn
assert turn.input_messages == request.messages, "Input messages do not match"
# Verify the turn was added to the session
assert len(session.turns) == 1, "Turn was not added to the session"
assert (
session.turns[0].input_messages == request.messages
), "Input messages do not match"
@pytest.mark.asyncio
@pytest.mark.parametrize(
"toolgroups, expected_memory, expected_code_interpreter",
[
([], False, False), # no tools
([MEMORY_TOOLGROUP], True, False), # memory only
([CODE_INTERPRETER_TOOLGROUP], False, True), # code interpreter only
([MEMORY_TOOLGROUP, CODE_INTERPRETER_TOOLGROUP], True, True), # all tools
],
)
async def test_chat_agent_tools(
get_agents_impl, toolgroups, expected_memory, expected_code_interpreter
):
impl = await get_agents_impl
agent_config = AgentConfig(
model="test_model",
instructions="You are a helpful assistant.",
toolgroups=toolgroups,
tool_choice=ToolChoice.auto,
enable_session_persistence=False,
input_shields=["test_shield"],
)
response = await impl.create_agent(agent_config)
chat_agent = await impl.get_agent(response.agent_id)
tool_defs, _ = await chat_agent._get_tool_defs()
if expected_memory:
assert MEMORY_QUERY_TOOL in tool_defs
if expected_code_interpreter:
assert BuiltinTool.code_interpreter in tool_defs
if expected_memory and expected_code_interpreter:
# override the tools for turn
new_tool_defs, _ = await chat_agent._get_tool_defs(
toolgroups_for_turn=[
AgentToolGroupWithArgs(
name=MEMORY_TOOLGROUP,
args={"memory_banks": ["test_memory_bank"]},
)
]
)
assert MEMORY_QUERY_TOOL in new_tool_defs
assert BuiltinTool.code_interpreter not in new_tool_defs

View file

@ -1,20 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from abc import ABC, abstractmethod
from typing import List
from llama_stack.apis.inference import Message
class BaseTool(ABC):
@abstractmethod
def get_name(self) -> str:
raise NotImplementedError
@abstractmethod
async def run(self, messages: List[Message]) -> List[Message]:
raise NotImplementedError

View file

@ -1,396 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import json
import logging
import re
import tempfile
from abc import abstractmethod
from typing import List, Optional
import requests
from .ipython_tool.code_execution import (
CodeExecutionContext,
CodeExecutionRequest,
CodeExecutor,
TOOLS_ATTACHMENT_KEY_REGEX,
)
from llama_stack.apis.inference import * # noqa: F403
from llama_stack.apis.agents import * # noqa: F403
from .base import BaseTool
log = logging.getLogger(__name__)
def interpret_content_as_attachment(content: str) -> Optional[Attachment]:
match = re.search(TOOLS_ATTACHMENT_KEY_REGEX, content)
if match:
snippet = match.group(1)
data = json.loads(snippet)
return Attachment(
url=URL(uri="file://" + data["filepath"]), mime_type=data["mimetype"]
)
return None
class SingleMessageBuiltinTool(BaseTool):
async def run(self, messages: List[CompletionMessage]) -> List[ToolResponseMessage]:
assert len(messages) == 1, f"Expected single message, got {len(messages)}"
message = messages[0]
assert len(message.tool_calls) == 1, "Expected a single tool call"
tool_call = messages[0].tool_calls[0]
query = tool_call.arguments["query"]
response: str = await self.run_impl(query)
message = ToolResponseMessage(
call_id=tool_call.call_id,
tool_name=tool_call.tool_name,
content=response,
)
return [message]
@abstractmethod
async def run_impl(self, query: str) -> str:
raise NotImplementedError()
class PhotogenTool(SingleMessageBuiltinTool):
def __init__(self, dump_dir: str) -> None:
self.dump_dir = dump_dir
def get_name(self) -> str:
return BuiltinTool.photogen.value
async def run_impl(self, query: str) -> str:
"""
Implement this to give the model an ability to generate images.
Return:
info = {
"filepath": str(image_filepath),
"mimetype": "image/png",
}
"""
raise NotImplementedError()
class SearchTool(SingleMessageBuiltinTool):
def __init__(self, engine: SearchEngineType, api_key: str, **kwargs) -> None:
self.api_key = api_key
self.engine_type = engine
if engine == SearchEngineType.bing:
self.engine = BingSearch(api_key, **kwargs)
elif engine == SearchEngineType.brave:
self.engine = BraveSearch(api_key, **kwargs)
elif engine == SearchEngineType.tavily:
self.engine = TavilySearch(api_key, **kwargs)
else:
raise ValueError(f"Unknown search engine: {engine}")
def get_name(self) -> str:
return BuiltinTool.brave_search.value
async def run_impl(self, query: str) -> str:
return await self.engine.search(query)
class BingSearch:
def __init__(self, api_key: str, top_k: int = 3, **kwargs) -> None:
self.api_key = api_key
self.top_k = top_k
async def search(self, query: str) -> str:
url = "https://api.bing.microsoft.com/v7.0/search"
headers = {
"Ocp-Apim-Subscription-Key": self.api_key,
}
params = {
"count": self.top_k,
"textDecorations": True,
"textFormat": "HTML",
"q": query,
}
response = requests.get(url=url, params=params, headers=headers)
response.raise_for_status()
clean = self._clean_response(response.json())
return json.dumps(clean)
def _clean_response(self, search_response):
clean_response = []
query = search_response["queryContext"]["originalQuery"]
if "webPages" in search_response:
pages = search_response["webPages"]["value"]
for p in pages:
selected_keys = {"name", "url", "snippet"}
clean_response.append(
{k: v for k, v in p.items() if k in selected_keys}
)
if "news" in search_response:
clean_news = []
news = search_response["news"]["value"]
for n in news:
selected_keys = {"name", "url", "description"}
clean_news.append({k: v for k, v in n.items() if k in selected_keys})
clean_response.append(clean_news)
return {"query": query, "top_k": clean_response}
class BraveSearch:
def __init__(self, api_key: str) -> None:
self.api_key = api_key
async def search(self, query: str) -> str:
url = "https://api.search.brave.com/res/v1/web/search"
headers = {
"X-Subscription-Token": self.api_key,
"Accept-Encoding": "gzip",
"Accept": "application/json",
}
payload = {"q": query}
response = requests.get(url=url, params=payload, headers=headers)
return json.dumps(self._clean_brave_response(response.json()))
def _clean_brave_response(self, search_response, top_k=3):
query = None
clean_response = []
if "query" in search_response:
if "original" in search_response["query"]:
query = search_response["query"]["original"]
if "mixed" in search_response:
mixed_results = search_response["mixed"]
for m in mixed_results["main"][:top_k]:
r_type = m["type"]
results = search_response[r_type]["results"]
if r_type == "web":
# For web data - add a single output from the search
idx = m["index"]
selected_keys = [
"type",
"title",
"url",
"description",
"date",
"extra_snippets",
]
cleaned = {
k: v for k, v in results[idx].items() if k in selected_keys
}
elif r_type == "faq":
# For faw data - take a list of all the questions & answers
selected_keys = ["type", "question", "answer", "title", "url"]
cleaned = []
for q in results:
cleaned.append(
{k: v for k, v in q.items() if k in selected_keys}
)
elif r_type == "infobox":
idx = m["index"]
selected_keys = [
"type",
"title",
"url",
"description",
"long_desc",
]
cleaned = {
k: v for k, v in results[idx].items() if k in selected_keys
}
elif r_type == "videos":
selected_keys = [
"type",
"url",
"title",
"description",
"date",
]
cleaned = []
for q in results:
cleaned.append(
{k: v for k, v in q.items() if k in selected_keys}
)
elif r_type == "locations":
# For faw data - take a list of all the questions & answers
selected_keys = [
"type",
"title",
"url",
"description",
"coordinates",
"postal_address",
"contact",
"rating",
"distance",
"zoom_level",
]
cleaned = []
for q in results:
cleaned.append(
{k: v for k, v in q.items() if k in selected_keys}
)
elif r_type == "news":
# For faw data - take a list of all the questions & answers
selected_keys = [
"type",
"title",
"url",
"description",
]
cleaned = []
for q in results:
cleaned.append(
{k: v for k, v in q.items() if k in selected_keys}
)
else:
cleaned = []
clean_response.append(cleaned)
return {"query": query, "top_k": clean_response}
class TavilySearch:
def __init__(self, api_key: str) -> None:
self.api_key = api_key
async def search(self, query: str) -> str:
response = requests.post(
"https://api.tavily.com/search",
json={"api_key": self.api_key, "query": query},
)
return json.dumps(self._clean_tavily_response(response.json()))
def _clean_tavily_response(self, search_response, top_k=3):
return {"query": search_response["query"], "top_k": search_response["results"]}
class WolframAlphaTool(SingleMessageBuiltinTool):
def __init__(self, api_key: str) -> None:
self.api_key = api_key
self.url = "https://api.wolframalpha.com/v2/query"
def get_name(self) -> str:
return BuiltinTool.wolfram_alpha.value
async def run_impl(self, query: str) -> str:
params = {
"input": query,
"appid": self.api_key,
"format": "plaintext",
"output": "json",
}
response = requests.get(
self.url,
params=params,
)
return json.dumps(self._clean_wolfram_alpha_response(response.json()))
def _clean_wolfram_alpha_response(self, wa_response):
remove = {
"queryresult": [
"datatypes",
"error",
"timedout",
"timedoutpods",
"numpods",
"timing",
"parsetiming",
"parsetimedout",
"recalculate",
"id",
"host",
"server",
"related",
"version",
{
"pods": [
"scanner",
"id",
"error",
"expressiontypes",
"states",
"infos",
"position",
"numsubpods",
]
},
"assumptions",
],
}
for main_key in remove:
for key_to_remove in remove[main_key]:
try:
if key_to_remove == "assumptions":
if "assumptions" in wa_response[main_key]:
del wa_response[main_key][key_to_remove]
if isinstance(key_to_remove, dict):
for sub_key in key_to_remove:
if sub_key == "pods":
for i in range(len(wa_response[main_key][sub_key])):
if (
wa_response[main_key][sub_key][i]["title"]
== "Result"
):
del wa_response[main_key][sub_key][i + 1 :]
break
sub_items = wa_response[main_key][sub_key]
for i in range(len(sub_items)):
for sub_key_to_remove in key_to_remove[sub_key]:
if sub_key_to_remove in sub_items[i]:
del sub_items[i][sub_key_to_remove]
elif key_to_remove in wa_response[main_key]:
del wa_response[main_key][key_to_remove]
except KeyError:
pass
return wa_response
class CodeInterpreterTool(BaseTool):
def __init__(self) -> None:
ctx = CodeExecutionContext(
matplotlib_dump_dir=tempfile.mkdtemp(),
)
self.code_executor = CodeExecutor(ctx)
def get_name(self) -> str:
return BuiltinTool.code_interpreter.value
async def run(self, messages: List[CompletionMessage]) -> List[ToolResponseMessage]:
message = messages[0]
assert len(message.tool_calls) == 1, "Expected a single tool call"
tool_call = messages[0].tool_calls[0]
script = tool_call.arguments["code"]
req = CodeExecutionRequest(scripts=[script])
res = self.code_executor.execute(req)
pieces = [res["process_status"]]
for out_type in ["stdout", "stderr"]:
res_out = res[out_type]
if res_out != "":
pieces.extend([f"[{out_type}]", res_out, f"[/{out_type}]"])
if out_type == "stderr":
log.error(f"ipython tool error: ↓\n{res_out}")
message = ToolResponseMessage(
call_id=tool_call.call_id,
tool_name=tool_call.tool_name,
content="\n".join(pieces),
)
return [message]

View file

@ -1,42 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import List
from llama_stack.apis.inference import Message
from llama_stack.apis.safety import Safety
from ..safety import ShieldRunnerMixin
from .builtin import BaseTool
class SafeTool(BaseTool, ShieldRunnerMixin):
"""A tool that makes other tools safety enabled"""
def __init__(
self,
tool: BaseTool,
safety_api: Safety,
input_shields: List[str] = None,
output_shields: List[str] = None,
):
self._tool = tool
ShieldRunnerMixin.__init__(
self, safety_api, input_shields=input_shields, output_shields=output_shields
)
def get_name(self) -> str:
return self._tool.get_name()
async def run(self, messages: List[Message]) -> List[Message]:
if self.input_shields:
await self.run_multiple_shields(messages, self.input_shields)
# run the underlying tool
res = await self._tool.run(messages)
if self.output_shields:
await self.run_multiple_shields(messages, self.output_shields)
return res

View file

@ -5,5 +5,14 @@
# the root directory of this source tree.
from pydantic import BaseModel
from llama_stack.distribution.utils.config_dirs import RUNTIME_BASE_DIR
from llama_stack.providers.utils.kvstore.config import (
KVStoreConfig,
SqliteKVStoreConfig,
)
class LocalFSDatasetIOConfig(BaseModel): ...
class LocalFSDatasetIOConfig(BaseModel):
kvstore: KVStoreConfig = SqliteKVStoreConfig(
db_path=(RUNTIME_BASE_DIR / "localfs_datasetio.db").as_posix()
) # Uses SQLite config specific to localfs storage

View file

@ -18,10 +18,14 @@ from llama_stack.apis.datasets import Dataset
from llama_stack.providers.datatypes import DatasetsProtocolPrivate
from llama_stack.providers.utils.datasetio.url_utils import get_dataframe_from_url
from llama_stack.providers.utils.kvstore import kvstore_impl
from .config import LocalFSDatasetIOConfig
DATASETS_PREFIX = "localfs_datasets:"
class BaseDataset(ABC):
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
@ -86,8 +90,22 @@ class LocalFSDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
self.config = config
# local registry for keeping track of datasets within the provider
self.dataset_infos = {}
self.kvstore = None
async def initialize(self) -> None: ...
async def initialize(self) -> None:
self.kvstore = await kvstore_impl(self.config.kvstore)
# Load existing datasets from kvstore
start_key = DATASETS_PREFIX
end_key = f"{DATASETS_PREFIX}\xff"
stored_datasets = await self.kvstore.range(start_key, end_key)
for dataset in stored_datasets:
dataset = Dataset.model_validate_json(dataset)
dataset_impl = PandasDataframeDataset(dataset)
self.dataset_infos[dataset.identifier] = DatasetInfo(
dataset_def=dataset,
dataset_impl=dataset_impl,
)
async def shutdown(self) -> None: ...
@ -95,6 +113,12 @@ class LocalFSDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
self,
dataset: Dataset,
) -> None:
# Store in kvstore
key = f"{DATASETS_PREFIX}{dataset.identifier}"
await self.kvstore.set(
key=key,
value=dataset.json(),
)
dataset_impl = PandasDataframeDataset(dataset)
self.dataset_infos[dataset.identifier] = DatasetInfo(
dataset_def=dataset,
@ -102,6 +126,8 @@ class LocalFSDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
)
async def unregister_dataset(self, dataset_id: str) -> None:
key = f"{DATASETS_PREFIX}{dataset_id}"
await self.kvstore.delete(key=key)
del self.dataset_infos[dataset_id]
async def get_rows_paginated(

View file

@ -6,7 +6,6 @@
import asyncio
import logging
from typing import AsyncGenerator, List, Optional, Union
from llama_models.llama3.api.datatypes import (
@ -37,7 +36,6 @@ from llama_stack.apis.inference import (
ToolCallParseStatus,
ToolChoice,
)
from llama_stack.apis.models import Model, ModelType
from llama_stack.providers.datatypes import ModelsProtocolPrivate
from llama_stack.providers.utils.inference.embedding_mixin import (
@ -262,7 +260,7 @@ class MetaReferenceInferenceImpl(
response_format: Optional[ResponseFormat] = None,
tools: Optional[List[ToolDefinition]] = None,
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
tool_prompt_format: Optional[ToolPromptFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> AsyncGenerator:

View file

@ -22,6 +22,7 @@ from llama_stack.providers.datatypes import Model, ModelsProtocolPrivate
from llama_stack.providers.utils.inference.embedding_mixin import (
SentenceTransformerEmbeddingMixin,
)
from .config import SentenceTransformersInferenceConfig
log = logging.getLogger(__name__)
@ -67,7 +68,7 @@ class SentenceTransformersInferenceImpl(
response_format: Optional[ResponseFormat] = None,
tools: Optional[List[ToolDefinition]] = None,
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
tool_prompt_format: Optional[ToolPromptFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> AsyncGenerator:

View file

@ -10,10 +10,8 @@ import uuid
from typing import AsyncGenerator, List, Optional
from llama_models.llama3.api.chat_format import ChatFormat
from llama_models.llama3.api.tokenizer import Tokenizer
from llama_models.sku_list import resolve_model
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.sampling_params import SamplingParams as VLLMSamplingParams
@ -36,7 +34,6 @@ from llama_stack.apis.inference import (
ToolPromptFormat,
)
from llama_stack.apis.models import Model
from llama_stack.providers.datatypes import ModelsProtocolPrivate
from llama_stack.providers.utils.inference.openai_compat import (
OpenAICompatCompletionChoice,
@ -50,7 +47,6 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
from .config import VLLMConfig
log = logging.getLogger(__name__)
@ -67,7 +63,7 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
self.formatter = ChatFormat(Tokenizer.get_instance())
async def initialize(self):
log.info("Initializing vLLM inference adapter")
log.info("Initializing vLLM inference provider.")
# Disable usage stats reporting. This would be a surprising thing for most
# people to find out was on by default.
@ -95,15 +91,36 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
self.engine = AsyncLLMEngine.from_engine_args(engine_args)
async def shutdown(self):
"""Shutdown the vLLM inference adapter."""
log.info("Shutting down vLLM inference adapter")
"""Shut down the vLLM inference adapter."""
log.info("Shutting down vLLM inference provider.")
if self.engine:
self.engine.shutdown_background_loop()
async def register_model(self, model: Model) -> None:
raise ValueError(
"You cannot dynamically add a model to a running vllm instance"
)
# Note that the return type of the superclass method is WRONG
async def register_model(self, model: Model) -> Model:
"""
Callback that is called when the server associates an inference endpoint
with an inference provider.
:param model: Object that encapsulates parameters necessary for identifying
a specific LLM.
:returns: The input ``Model`` object. It may or may not be permissible
to change fields before returning this object.
"""
log.info(f"Registering model {model.identifier} with vLLM inference provider.")
# The current version of this provided is hard-coded to serve only
# the model specified in the YAML config file.
configured_model = resolve_model(self.config.model)
registered_model = resolve_model(model.model_id)
if configured_model.core_model_id != registered_model.core_model_id:
raise ValueError(
f"Requested model '{model.identifier}' is different from "
f"model '{self.config.model}' that this provider "
f"is configured to serve"
)
return model
def _sampling_params(self, sampling_params: SamplingParams) -> VLLMSamplingParams:
if sampling_params is None:
@ -146,7 +163,7 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
sampling_params: Optional[SamplingParams] = SamplingParams(),
tools: Optional[List[ToolDefinition]] = None,
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
tool_prompt_format: Optional[ToolPromptFormat] = None,
response_format: Optional[ResponseFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
@ -167,7 +184,9 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
log.info("Sampling params: %s", sampling_params)
request_id = _random_uuid()
prompt = await chat_completion_request_to_prompt(request, self.formatter)
prompt = await chat_completion_request_to_prompt(
request, self.config.model, self.formatter
)
vllm_sampling_params = self._sampling_params(request.sampling_params)
results_generator = self.engine.generate(
prompt, vllm_sampling_params, request_id

View file

@ -16,8 +16,6 @@ import torch
from llama_models.datatypes import Model
from llama_models.sku_list import resolve_model
from llama_stack.apis.post_training import DatasetFormat
from pydantic import BaseModel
from torchtune.data._messages import InputOutputToMessages, ShareGPTToMessages
@ -27,6 +25,8 @@ from torchtune.models.llama3_1 import lora_llama3_1_8b
from torchtune.models.llama3_2 import lora_llama3_2_3b
from torchtune.modules.transforms import Transform
from llama_stack.apis.post_training import DatasetFormat
class ModelConfig(BaseModel):
model_definition: Any

View file

@ -14,6 +14,24 @@ from typing import Any, Dict, List, Optional, Tuple
import torch
from llama_models.sku_list import resolve_model
from torch import nn
from torch.optim import Optimizer
from torch.utils.data import DataLoader, DistributedSampler
from torchtune import modules, training, utils as torchtune_utils
from torchtune.data import padded_collate_sft
from torchtune.modules.loss import CEWithChunkedOutputLoss
from torchtune.modules.peft import (
get_adapter_params,
get_adapter_state_dict,
get_lora_module_names,
get_merged_lora_ckpt,
set_trainable_params,
validate_missing_and_unexpected_for_lora,
)
from torchtune.training.lr_schedulers import get_cosine_schedule_with_warmup
from torchtune.training.metric_logging import DiskLogger
from tqdm import tqdm
from llama_stack.apis.common.training_types import PostTrainingMetric
from llama_stack.apis.datasetio import DatasetIO
@ -41,24 +59,6 @@ from llama_stack.providers.inline.post_training.torchtune.config import (
TorchtunePostTrainingConfig,
)
from llama_stack.providers.inline.post_training.torchtune.datasets.sft import SFTDataset
from torch import nn
from torch.optim import Optimizer
from torch.utils.data import DataLoader, DistributedSampler
from torchtune import modules, training, utils as torchtune_utils
from torchtune.data import padded_collate_sft
from torchtune.modules.loss import CEWithChunkedOutputLoss
from torchtune.modules.peft import (
get_adapter_params,
get_adapter_state_dict,
get_lora_module_names,
get_merged_lora_ckpt,
set_trainable_params,
validate_missing_and_unexpected_for_lora,
)
from torchtune.training.lr_schedulers import get_cosine_schedule_with_warmup
from torchtune.training.metric_logging import DiskLogger
from tqdm import tqdm
log = logging.getLogger(__name__)

View file

@ -4,10 +4,10 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .config import CodeShieldConfig
from .config import CodeScannerConfig
async def get_provider_impl(config: CodeShieldConfig, deps):
async def get_provider_impl(config: CodeScannerConfig, deps):
from .code_scanner import MetaReferenceCodeScannerSafetyImpl
impl = MetaReferenceCodeScannerSafetyImpl(config, deps)

View file

@ -156,7 +156,7 @@ class BraintrustScoringImpl(
provider_data = self.get_request_provider_data()
if provider_data is None or not provider_data.openai_api_key:
raise ValueError(
'Pass OpenAI API Key in the header X-LlamaStack-ProviderData as { "openai_api_key": <your api key>}'
'Pass OpenAI API Key in the header X-LlamaStack-Provider-Data as { "openai_api_key": <your api key>}'
)
self.config.openai_api_key = provider_data.openai_api_key

View file

@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@ -0,0 +1,16 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .code_interpreter import CodeInterpreterToolRuntimeImpl
from .config import CodeInterpreterToolConfig
__all__ = ["CodeInterpreterToolConfig", "CodeInterpreterToolRuntimeImpl"]
async def get_provider_impl(config: CodeInterpreterToolConfig, _deps):
impl = CodeInterpreterToolRuntimeImpl(config)
await impl.initialize()
return impl

View file

@ -0,0 +1,75 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import logging
import tempfile
from typing import Any, Dict, List, Optional
from llama_stack.apis.common.content_types import URL
from llama_stack.apis.tools import (
Tool,
ToolDef,
ToolInvocationResult,
ToolParameter,
ToolRuntime,
)
from llama_stack.providers.datatypes import ToolsProtocolPrivate
from .code_execution import CodeExecutionContext, CodeExecutionRequest, CodeExecutor
from .config import CodeInterpreterToolConfig
log = logging.getLogger(__name__)
class CodeInterpreterToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
def __init__(self, config: CodeInterpreterToolConfig):
self.config = config
ctx = CodeExecutionContext(
matplotlib_dump_dir=tempfile.mkdtemp(),
)
self.code_executor = CodeExecutor(ctx)
async def initialize(self):
pass
async def register_tool(self, tool: Tool):
pass
async def unregister_tool(self, tool_id: str) -> None:
return
async def list_runtime_tools(
self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
) -> List[ToolDef]:
return [
ToolDef(
name="code_interpreter",
description="Execute code",
parameters=[
ToolParameter(
name="code",
description="The code to execute",
parameter_type="string",
),
],
)
]
async def invoke_tool(
self, tool_name: str, args: Dict[str, Any]
) -> ToolInvocationResult:
script = args["code"]
req = CodeExecutionRequest(scripts=[script])
res = self.code_executor.execute(req)
pieces = [res["process_status"]]
for out_type in ["stdout", "stderr"]:
res_out = res[out_type]
if res_out != "":
pieces.extend([f"[{out_type}]", res_out, f"[/{out_type}]"])
if out_type == "stderr":
log.error(f"ipython tool error: ↓\n{res_out}")
return ToolInvocationResult(content="\n".join(pieces))

View file

@ -0,0 +1,11 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from pydantic import BaseModel
class CodeInterpreterToolConfig(BaseModel):
pass

View file

@ -0,0 +1,20 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict
from llama_stack.providers.datatypes import Api
from .config import MemoryToolRuntimeConfig
from .memory import MemoryToolRuntimeImpl
async def get_provider_impl(config: MemoryToolRuntimeConfig, deps: Dict[str, Any]):
impl = MemoryToolRuntimeImpl(
config, deps[Api.memory], deps[Api.memory_banks], deps[Api.inference]
)
await impl.initialize()
return impl

View file

@ -0,0 +1,90 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from enum import Enum
from typing import Annotated, List, Literal, Union
from pydantic import BaseModel, Field
class _MemoryBankConfigCommon(BaseModel):
bank_id: str
class VectorMemoryBankConfig(_MemoryBankConfigCommon):
type: Literal["vector"] = "vector"
class KeyValueMemoryBankConfig(_MemoryBankConfigCommon):
type: Literal["keyvalue"] = "keyvalue"
keys: List[str] # what keys to focus on
class KeywordMemoryBankConfig(_MemoryBankConfigCommon):
type: Literal["keyword"] = "keyword"
class GraphMemoryBankConfig(_MemoryBankConfigCommon):
type: Literal["graph"] = "graph"
entities: List[str] # what entities to focus on
MemoryBankConfig = Annotated[
Union[
VectorMemoryBankConfig,
KeyValueMemoryBankConfig,
KeywordMemoryBankConfig,
GraphMemoryBankConfig,
],
Field(discriminator="type"),
]
class MemoryQueryGenerator(Enum):
default = "default"
llm = "llm"
custom = "custom"
class DefaultMemoryQueryGeneratorConfig(BaseModel):
type: Literal[MemoryQueryGenerator.default.value] = (
MemoryQueryGenerator.default.value
)
sep: str = " "
class LLMMemoryQueryGeneratorConfig(BaseModel):
type: Literal[MemoryQueryGenerator.llm.value] = MemoryQueryGenerator.llm.value
model: str
template: str
class CustomMemoryQueryGeneratorConfig(BaseModel):
type: Literal[MemoryQueryGenerator.custom.value] = MemoryQueryGenerator.custom.value
MemoryQueryGeneratorConfig = Annotated[
Union[
DefaultMemoryQueryGeneratorConfig,
LLMMemoryQueryGeneratorConfig,
CustomMemoryQueryGeneratorConfig,
],
Field(discriminator="type"),
]
class MemoryToolConfig(BaseModel):
memory_bank_configs: List[MemoryBankConfig] = Field(default_factory=list)
class MemoryToolRuntimeConfig(BaseModel):
# This config defines how a query is generated using the messages
# for memory bank retrieval.
query_generator_config: MemoryQueryGeneratorConfig = Field(
default=DefaultMemoryQueryGeneratorConfig()
)
max_tokens_in_context: int = 4096
max_chunks: int = 5

View file

@ -4,25 +4,29 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import List
from jinja2 import Template
from pydantic import BaseModel
from llama_stack.apis.agents import (
from llama_stack.apis.common.content_types import InterleavedContent
from llama_stack.apis.inference import UserMessage
from llama_stack.providers.utils.inference.prompt_adapter import (
interleaved_content_as_str,
)
from .config import (
DefaultMemoryQueryGeneratorConfig,
LLMMemoryQueryGeneratorConfig,
MemoryQueryGenerator,
MemoryQueryGeneratorConfig,
)
from llama_stack.apis.inference import Message, UserMessage
from llama_stack.providers.utils.inference.prompt_adapter import (
interleaved_content_as_str,
)
async def generate_rag_query(
config: MemoryQueryGeneratorConfig,
messages: List[Message],
messages: List[InterleavedContent],
**kwargs,
):
"""
@ -40,21 +44,26 @@ async def generate_rag_query(
async def default_rag_query_generator(
config: DefaultMemoryQueryGeneratorConfig,
messages: List[Message],
messages: List[InterleavedContent],
**kwargs,
):
return config.sep.join(interleaved_content_as_str(m.content) for m in messages)
return config.sep.join(interleaved_content_as_str(m) for m in messages)
async def llm_rag_query_generator(
config: LLMMemoryQueryGeneratorConfig,
messages: List[Message],
messages: List[InterleavedContent],
**kwargs,
):
assert "inference_api" in kwargs, "LLMRAGQueryGenerator needs inference_api"
inference_api = kwargs["inference_api"]
m_dict = {"messages": [m.model_dump() for m in messages]}
m_dict = {
"messages": [
message.model_dump() if isinstance(message, BaseModel) else message
for message in messages
]
}
template = Template(config.template)
content = template.render(m_dict)

View file

@ -0,0 +1,146 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import asyncio
import logging
import secrets
import string
from typing import Any, Dict, List, Optional
from llama_stack.apis.common.content_types import URL
from llama_stack.apis.inference import Inference, InterleavedContent
from llama_stack.apis.memory import Memory, QueryDocumentsResponse
from llama_stack.apis.memory_banks import MemoryBanks
from llama_stack.apis.tools import (
ToolDef,
ToolInvocationResult,
ToolParameter,
ToolRuntime,
)
from llama_stack.providers.datatypes import ToolsProtocolPrivate
from llama_stack.providers.utils.memory.vector_store import concat_interleaved_content
from .config import MemoryToolConfig, MemoryToolRuntimeConfig
from .context_retriever import generate_rag_query
log = logging.getLogger(__name__)
def make_random_string(length: int = 8):
return "".join(
secrets.choice(string.ascii_letters + string.digits) for _ in range(length)
)
class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
def __init__(
self,
config: MemoryToolRuntimeConfig,
memory_api: Memory,
memory_banks_api: MemoryBanks,
inference_api: Inference,
):
self.config = config
self.memory_api = memory_api
self.memory_banks_api = memory_banks_api
self.inference_api = inference_api
async def initialize(self):
pass
async def list_runtime_tools(
self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
) -> List[ToolDef]:
return [
ToolDef(
name="query_memory",
description="Retrieve context from memory",
parameters=[
ToolParameter(
name="messages",
description="The input messages to search for",
parameter_type="array",
),
],
)
]
async def _retrieve_context(
self, input_messages: List[InterleavedContent], bank_ids: List[str]
) -> Optional[List[InterleavedContent]]:
if not bank_ids:
return None
query = await generate_rag_query(
self.config.query_generator_config,
input_messages,
inference_api=self.inference_api,
)
tasks = [
self.memory_api.query_documents(
bank_id=bank_id,
query=query,
params={
"max_chunks": self.config.max_chunks,
},
)
for bank_id in bank_ids
]
results: List[QueryDocumentsResponse] = await asyncio.gather(*tasks)
chunks = [c for r in results for c in r.chunks]
scores = [s for r in results for s in r.scores]
if not chunks:
return None
# sort by score
chunks, scores = zip(
*sorted(zip(chunks, scores), key=lambda x: x[1], reverse=True)
)
tokens = 0
picked = []
for c in chunks[: self.config.max_chunks]:
tokens += c.token_count
if tokens > self.config.max_tokens_in_context:
log.error(
f"Using {len(picked)} chunks; reached max tokens in context: {tokens}",
)
break
picked.append(f"id:{c.document_id}; content:{c.content}")
return [
"Here are the retrieved documents for relevant context:\n=== START-RETRIEVED-CONTEXT ===\n",
*picked,
"\n=== END-RETRIEVED-CONTEXT ===\n",
]
async def invoke_tool(
self, tool_name: str, args: Dict[str, Any]
) -> ToolInvocationResult:
tool = await self.tool_store.get_tool(tool_name)
tool_group = await self.tool_store.get_tool_group(tool.toolgroup_id)
final_args = tool_group.args or {}
final_args.update(args)
config = MemoryToolConfig()
if tool.metadata and tool.metadata.get("config") is not None:
config = MemoryToolConfig(**tool.metadata["config"])
if "memory_bank_ids" in final_args:
bank_ids = final_args["memory_bank_ids"]
else:
bank_ids = [
bank_config.bank_id for bank_config in config.memory_bank_configs
]
if "messages" not in final_args:
raise ValueError("messages are required")
context = await self._retrieve_context(
final_args["messages"],
bank_ids,
)
if context is None:
context = []
return ToolInvocationResult(
content=concat_interleaved_content(context), error_code=0
)

View file

@ -35,6 +35,8 @@ def available_providers() -> List[ProviderSpec]:
Api.safety,
Api.memory,
Api.memory_banks,
Api.tool_runtime,
Api.tool_groups,
],
),
remote_provider_spec(

View file

@ -19,11 +19,58 @@ def available_providers() -> List[ProviderSpec]:
return [
InlineProviderSpec(
api=Api.tool_runtime,
provider_type="inline::brave-search",
provider_type="inline::memory-runtime",
pip_packages=[],
module="llama_stack.providers.inline.tool_runtime.brave_search",
config_class="llama_stack.providers.inline.tool_runtime.brave_search.config.BraveSearchToolConfig",
provider_data_validator="llama_stack.providers.inline.tool_runtime.brave_search.BraveSearchToolProviderDataValidator",
module="llama_stack.providers.inline.tool_runtime.memory",
config_class="llama_stack.providers.inline.tool_runtime.memory.config.MemoryToolRuntimeConfig",
api_dependencies=[Api.memory, Api.memory_banks, Api.inference],
),
InlineProviderSpec(
api=Api.tool_runtime,
provider_type="inline::code-interpreter",
pip_packages=[],
module="llama_stack.providers.inline.tool_runtime.code_interpreter",
config_class="llama_stack.providers.inline.tool_runtime.code_interpreter.config.CodeInterpreterToolConfig",
),
remote_provider_spec(
api=Api.tool_runtime,
adapter=AdapterSpec(
adapter_type="brave-search",
module="llama_stack.providers.remote.tool_runtime.brave_search",
config_class="llama_stack.providers.remote.tool_runtime.brave_search.config.BraveSearchToolConfig",
pip_packages=["requests"],
provider_data_validator="llama_stack.providers.remote.tool_runtime.brave_search.BraveSearchToolProviderDataValidator",
),
),
remote_provider_spec(
api=Api.tool_runtime,
adapter=AdapterSpec(
adapter_type="bing-search",
module="llama_stack.providers.remote.tool_runtime.bing_search",
config_class="llama_stack.providers.remote.tool_runtime.bing_search.config.BingSearchToolConfig",
pip_packages=["requests"],
provider_data_validator="llama_stack.providers.remote.tool_runtime.bing_search.BingSearchToolProviderDataValidator",
),
),
remote_provider_spec(
api=Api.tool_runtime,
adapter=AdapterSpec(
adapter_type="tavily-search",
module="llama_stack.providers.remote.tool_runtime.tavily_search",
config_class="llama_stack.providers.remote.tool_runtime.tavily_search.config.TavilySearchToolConfig",
pip_packages=["requests"],
provider_data_validator="llama_stack.providers.remote.tool_runtime.tavily_search.TavilySearchToolProviderDataValidator",
),
),
remote_provider_spec(
api=Api.tool_runtime,
adapter=AdapterSpec(
adapter_type="wolfram-alpha",
module="llama_stack.providers.remote.tool_runtime.wolfram_alpha",
config_class="llama_stack.providers.remote.tool_runtime.wolfram_alpha.config.WolframAlphaToolConfig",
pip_packages=["requests"],
provider_data_validator="llama_stack.providers.remote.tool_runtime.wolfram_alpha.WolframAlphaToolProviderDataValidator",
),
),
remote_provider_spec(
api=Api.tool_runtime,

View file

@ -10,7 +10,6 @@ from typing import AsyncGenerator, AsyncIterator, Dict, List, Optional, Union
from botocore.client import BaseClient
from llama_models.datatypes import CoreModelId
from llama_models.llama3.api.chat_format import ChatFormat
from llama_models.llama3.api.tokenizer import Tokenizer
from llama_stack.apis.common.content_types import InterleavedContent
@ -30,7 +29,6 @@ from llama_stack.apis.inference import (
)
from llama_stack.providers.remote.inference.bedrock.config import BedrockConfig
from llama_stack.providers.utils.bedrock.client import create_bedrock_client
from llama_stack.providers.utils.inference.model_registry import (
build_model_alias,
ModelRegistryHelper,
@ -47,7 +45,6 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
interleaved_content_as_str,
)
MODEL_ALIASES = [
build_model_alias(
"meta.llama3-1-8b-instruct-v1:0",
@ -101,7 +98,7 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
response_format: Optional[ResponseFormat] = None,
tools: Optional[List[ToolDefinition]] = None,
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
tool_prompt_format: Optional[ToolPromptFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> Union[

View file

@ -7,11 +7,8 @@
from typing import AsyncGenerator, List, Optional, Union
from cerebras.cloud.sdk import AsyncCerebras
from llama_models.datatypes import CoreModelId
from llama_models.llama3.api.chat_format import ChatFormat
from llama_models.llama3.api.tokenizer import Tokenizer
from llama_stack.apis.common.content_types import InterleavedContent
@ -29,7 +26,6 @@ from llama_stack.apis.inference import (
ToolDefinition,
ToolPromptFormat,
)
from llama_stack.providers.utils.inference.model_registry import (
build_model_alias,
ModelRegistryHelper,
@ -48,7 +44,6 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
from .config import CerebrasImplConfig
model_aliases = [
build_model_alias(
"llama3.1-8b",
@ -130,7 +125,7 @@ class CerebrasInferenceAdapter(ModelRegistryHelper, Inference):
sampling_params: Optional[SamplingParams] = SamplingParams(),
tools: Optional[List[ToolDefinition]] = None,
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
tool_prompt_format: Optional[ToolPromptFormat] = None,
response_format: Optional[ResponseFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,

View file

@ -7,11 +7,8 @@
from typing import AsyncGenerator, List, Optional
from llama_models.datatypes import CoreModelId
from llama_models.llama3.api.chat_format import ChatFormat
from llama_models.llama3.api.tokenizer import Tokenizer
from openai import OpenAI
from llama_stack.apis.common.content_types import InterleavedContent
@ -28,7 +25,6 @@ from llama_stack.apis.inference import (
ToolDefinition,
ToolPromptFormat,
)
from llama_stack.providers.utils.inference.model_registry import (
build_model_alias,
ModelRegistryHelper,
@ -44,7 +40,6 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
from .config import DatabricksImplConfig
model_aliases = [
build_model_alias(
"databricks-meta-llama-3-1-70b-instruct",
@ -91,7 +86,7 @@ class DatabricksInferenceAdapter(ModelRegistryHelper, Inference):
response_format: Optional[ResponseFormat] = None,
tools: Optional[List[ToolDefinition]] = None,
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
tool_prompt_format: Optional[ToolPromptFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> AsyncGenerator:

View file

@ -22,7 +22,7 @@ class FireworksImplConfig(BaseModel):
)
@classmethod
def sample_run_config(cls, __distro_dir__: str) -> Dict[str, Any]:
def sample_run_config(cls, **kwargs) -> Dict[str, Any]:
return {
"url": "https://api.fireworks.ai/inference/v1",
"api_key": "${env.FIREWORKS_API_KEY}",

View file

@ -8,7 +8,6 @@ from typing import AsyncGenerator, List, Optional, Union
from fireworks.client import Fireworks
from llama_models.datatypes import CoreModelId
from llama_models.llama3.api.chat_format import ChatFormat
from llama_models.llama3.api.tokenizer import Tokenizer
@ -52,46 +51,45 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
from .config import FireworksImplConfig
MODEL_ALIASES = [
build_model_alias(
"fireworks/llama-v3p1-8b-instruct",
"accounts/fireworks/models/llama-v3p1-8b-instruct",
CoreModelId.llama3_1_8b_instruct.value,
),
build_model_alias(
"fireworks/llama-v3p1-70b-instruct",
"accounts/fireworks/models/llama-v3p1-70b-instruct",
CoreModelId.llama3_1_70b_instruct.value,
),
build_model_alias(
"fireworks/llama-v3p1-405b-instruct",
"accounts/fireworks/models/llama-v3p1-405b-instruct",
CoreModelId.llama3_1_405b_instruct.value,
),
build_model_alias(
"fireworks/llama-v3p2-1b-instruct",
"accounts/fireworks/models/llama-v3p2-1b-instruct",
CoreModelId.llama3_2_1b_instruct.value,
),
build_model_alias(
"fireworks/llama-v3p2-3b-instruct",
"accounts/fireworks/models/llama-v3p2-3b-instruct",
CoreModelId.llama3_2_3b_instruct.value,
),
build_model_alias(
"fireworks/llama-v3p2-11b-vision-instruct",
"accounts/fireworks/models/llama-v3p2-11b-vision-instruct",
CoreModelId.llama3_2_11b_vision_instruct.value,
),
build_model_alias(
"fireworks/llama-v3p2-90b-vision-instruct",
"accounts/fireworks/models/llama-v3p2-90b-vision-instruct",
CoreModelId.llama3_2_90b_vision_instruct.value,
),
build_model_alias(
"fireworks/llama-v3p3-70b-instruct",
"accounts/fireworks/models/llama-v3p3-70b-instruct",
CoreModelId.llama3_3_70b_instruct.value,
),
build_model_alias(
"fireworks/llama-guard-3-8b",
"accounts/fireworks/models/llama-guard-3-8b",
CoreModelId.llama_guard_3_8b.value,
),
build_model_alias(
"fireworks/llama-guard-3-11b-vision",
"accounts/fireworks/models/llama-guard-3-11b-vision",
CoreModelId.llama_guard_3_11b_vision.value,
),
]
@ -118,7 +116,7 @@ class FireworksInferenceAdapter(
provider_data = self.get_request_provider_data()
if provider_data is None or not provider_data.fireworks_api_key:
raise ValueError(
'Pass Fireworks API Key in the header X-LlamaStack-ProviderData as { "fireworks_api_key": <your api key>}'
'Pass Fireworks API Key in the header X-LlamaStack-Provider-Data as { "fireworks_api_key": <your api key>}'
)
return provider_data.fireworks_api_key
@ -198,7 +196,7 @@ class FireworksInferenceAdapter(
sampling_params: Optional[SamplingParams] = SamplingParams(),
tools: Optional[List[ToolDefinition]] = None,
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
tool_prompt_format: Optional[ToolPromptFormat] = None,
response_format: Optional[ResponseFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,

View file

@ -7,6 +7,7 @@
import warnings
from typing import AsyncIterator, List, Optional, Union
import groq
from groq import Groq
from llama_models.datatypes import SamplingParams
from llama_models.llama3.api.datatypes import ToolDefinition, ToolPromptFormat
@ -33,6 +34,7 @@ from llama_stack.providers.utils.inference.model_registry import (
build_model_alias_with_just_provider_model_id,
ModelRegistryHelper,
)
from .groq_utils import (
convert_chat_completion_request,
convert_chat_completion_response,
@ -94,9 +96,7 @@ class GroqInferenceAdapter(Inference, ModelRegistryHelper, NeedsRequestProviderD
response_format: Optional[ResponseFormat] = None,
tools: Optional[List[ToolDefinition]] = None,
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
tool_prompt_format: Optional[
ToolPromptFormat
] = None, # API default is ToolPromptFormat.json, we default to None to detect user input
tool_prompt_format: Optional[ToolPromptFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> Union[
@ -124,7 +124,16 @@ class GroqInferenceAdapter(Inference, ModelRegistryHelper, NeedsRequestProviderD
)
)
response = self._get_client().chat.completions.create(**request)
try:
response = self._get_client().chat.completions.create(**request)
except groq.BadRequestError as e:
if e.body.get("error", {}).get("code") == "tool_use_failed":
# For smaller models, Groq may fail to call a tool even when the request is well formed
raise ValueError(
"Groq failed to call a tool", e.body.get("error", {})
) from e
else:
raise e
if stream:
return convert_chat_completion_response_stream(response)
@ -145,6 +154,6 @@ class GroqInferenceAdapter(Inference, ModelRegistryHelper, NeedsRequestProviderD
provider_data = self.get_request_provider_data()
if provider_data is None or not provider_data.groq_api_key:
raise ValueError(
'Pass Groq API Key in the header X-LlamaStack-ProviderData as { "groq_api_key": "<your api key>" }'
'Pass Groq API Key in the header X-LlamaStack-Provider-Data as { "groq_api_key": "<your api key>" }'
)
return Groq(api_key=provider_data.groq_api_key)

View file

@ -4,6 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import json
import warnings
from typing import AsyncGenerator, Literal
@ -14,14 +15,20 @@ from groq.types.chat.chat_completion_assistant_message_param import (
)
from groq.types.chat.chat_completion_chunk import ChatCompletionChunk
from groq.types.chat.chat_completion_message_param import ChatCompletionMessageParam
from groq.types.chat.chat_completion_message_tool_call import (
ChatCompletionMessageToolCall,
)
from groq.types.chat.chat_completion_system_message_param import (
ChatCompletionSystemMessageParam,
)
from groq.types.chat.chat_completion_tool_param import ChatCompletionToolParam
from groq.types.chat.chat_completion_user_message_param import (
ChatCompletionUserMessageParam,
)
from groq.types.chat.completion_create_params import CompletionCreateParams
from groq.types.shared.function_definition import FunctionDefinition
from llama_models.llama3.api.datatypes import ToolParamDefinition
from llama_stack.apis.inference import (
ChatCompletionRequest,
@ -32,6 +39,11 @@ from llama_stack.apis.inference import (
CompletionMessage,
Message,
StopReason,
ToolCall,
ToolCallDelta,
ToolCallParseStatus,
ToolDefinition,
ToolPromptFormat,
)
@ -59,8 +71,8 @@ def convert_chat_completion_request(
# so we exclude it for now
warnings.warn("repetition_penalty is not supported")
if request.tools:
warnings.warn("tools are not supported yet")
if request.tool_prompt_format != ToolPromptFormat.json:
warnings.warn("tool_prompt_format is not used by Groq. Ignoring.")
return CompletionCreateParams(
model=request.model,
@ -71,6 +83,8 @@ def convert_chat_completion_request(
max_tokens=request.sampling_params.max_tokens or None,
temperature=request.sampling_params.temperature,
top_p=request.sampling_params.top_p,
tools=[_convert_groq_tool_definition(tool) for tool in request.tools or []],
tool_choice=request.tool_choice.value if request.tool_choice else None,
)
@ -87,17 +101,64 @@ def _convert_message(message: Message) -> ChatCompletionMessageParam:
raise ValueError(f"Invalid message role: {message.role}")
def _convert_groq_tool_definition(tool_definition: ToolDefinition) -> dict:
# Groq requires a description for function tools
if tool_definition.description is None:
raise AssertionError("tool_definition.description is required")
tool_parameters = tool_definition.parameters or {}
return ChatCompletionToolParam(
type="function",
function=FunctionDefinition(
name=tool_definition.tool_name,
description=tool_definition.description,
parameters={
key: _convert_groq_tool_parameter(param)
for key, param in tool_parameters.items()
},
),
)
def _convert_groq_tool_parameter(tool_parameter: ToolParamDefinition) -> dict:
param = {
"type": tool_parameter.param_type,
}
if tool_parameter.description is not None:
param["description"] = tool_parameter.description
if tool_parameter.required is not None:
param["required"] = tool_parameter.required
if tool_parameter.default is not None:
param["default"] = tool_parameter.default
return param
def convert_chat_completion_response(
response: ChatCompletion,
) -> ChatCompletionResponse:
# groq only supports n=1 at time of writing, so there is only one choice
choice = response.choices[0]
return ChatCompletionResponse(
completion_message=CompletionMessage(
content=choice.message.content,
stop_reason=_map_finish_reason_to_stop_reason(choice.finish_reason),
),
)
if choice.finish_reason == "tool_calls":
tool_calls = [
_convert_groq_tool_call(tool_call)
for tool_call in choice.message.tool_calls
]
return ChatCompletionResponse(
completion_message=CompletionMessage(
tool_calls=tool_calls,
stop_reason=StopReason.end_of_message,
# Content is not optional
content="",
),
logprobs=None,
)
else:
return ChatCompletionResponse(
completion_message=CompletionMessage(
content=choice.message.content,
stop_reason=_map_finish_reason_to_stop_reason(choice.finish_reason),
),
)
def _map_finish_reason_to_stop_reason(
@ -116,7 +177,7 @@ def _map_finish_reason_to_stop_reason(
elif finish_reason == "length":
return StopReason.out_of_tokens
elif finish_reason == "tool_calls":
raise NotImplementedError("tool_calls is not supported yet")
return StopReason.end_of_message
else:
raise ValueError(f"Invalid finish reason: {finish_reason}")
@ -129,25 +190,50 @@ async def convert_chat_completion_response_stream(
for chunk in stream:
choice = chunk.choices[0]
# We assume there's only one finish_reason for the entire stream.
# We collect the last finish_reason
if choice.finish_reason:
stop_reason = _map_finish_reason_to_stop_reason(choice.finish_reason)
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=event_type,
delta=choice.delta.content or "",
logprobs=None,
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.complete,
delta=choice.delta.content or "",
logprobs=None,
stop_reason=_map_finish_reason_to_stop_reason(choice.finish_reason),
)
)
elif choice.delta.tool_calls:
# We assume there is only one tool call per chunk, but emit a warning in case we're wrong
if len(choice.delta.tool_calls) > 1:
warnings.warn(
"Groq returned multiple tool calls in one chunk. Using the first one, ignoring the rest."
)
# We assume Groq produces fully formed tool calls for each chunk
tool_call = _convert_groq_tool_call(choice.delta.tool_calls[0])
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=event_type,
delta=ToolCallDelta(
content=tool_call,
parse_status=ToolCallParseStatus.success,
),
)
)
else:
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=event_type,
delta=choice.delta.content or "",
logprobs=None,
)
)
)
event_type = ChatCompletionResponseEventType.progress
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.complete,
delta="",
logprobs=None,
stop_reason=stop_reason,
)
def _convert_groq_tool_call(tool_call: ChatCompletionMessageToolCall) -> ToolCall:
return ToolCall(
call_id=tool_call.id,
tool_name=tool_call.function.name,
# Note that Groq may return a string that is not valid JSON here
# So this may raise a 500 error. Going to leave this as is to see
# how big of an issue this is and what we can do about it.
arguments=json.loads(tool_call.function.arguments),
)

View file

@ -175,9 +175,7 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
response_format: Optional[ResponseFormat] = None,
tools: Optional[List[ToolDefinition]] = None,
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
tool_prompt_format: Optional[
ToolPromptFormat
] = None, # API default is ToolPromptFormat.json, we default to None to detect user input
tool_prompt_format: Optional[ToolPromptFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> Union[

View file

@ -144,7 +144,7 @@ def _convert_message(message: Message | Dict) -> OpenAIChatCompletionMessage:
message = UserMessage(**message)
elif message["role"] == "assistant":
message = CompletionMessage(**message)
elif message["role"] == "ipython":
elif message["role"] == "tool":
message = ToolResponseMessage(**message)
elif message["role"] == "system":
message = SystemMessage(**message)

View file

@ -9,7 +9,6 @@ from typing import AsyncGenerator, List, Optional, Union
import httpx
from llama_models.datatypes import CoreModelId
from llama_models.llama3.api.chat_format import ChatFormat
from llama_models.llama3.api.tokenizer import Tokenizer
from ollama import AsyncClient
@ -35,7 +34,6 @@ from llama_stack.apis.inference import (
)
from llama_stack.apis.models import Model, ModelType
from llama_stack.providers.datatypes import ModelsProtocolPrivate
from llama_stack.providers.utils.inference.model_registry import (
build_model_alias,
build_model_alias_with_just_provider_model_id,
@ -222,7 +220,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
response_format: Optional[ResponseFormat] = None,
tools: Optional[List[ToolDefinition]] = None,
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
tool_prompt_format: Optional[ToolPromptFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> AsyncGenerator:

View file

@ -30,13 +30,11 @@ from llama_stack.apis.inference import (
ToolPromptFormat,
)
from llama_stack.apis.models import Model
from llama_stack.providers.datatypes import ModelsProtocolPrivate
from llama_stack.providers.utils.inference.model_registry import (
build_model_alias,
ModelRegistryHelper,
)
from llama_stack.providers.utils.inference.openai_compat import (
get_sampling_options,
OpenAICompatCompletionChoice,
@ -205,7 +203,7 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
sampling_params: Optional[SamplingParams] = SamplingParams(),
tools: Optional[List[ToolDefinition]] = None,
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
tool_prompt_format: Optional[ToolPromptFormat] = None,
response_format: Optional[ResponseFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,

View file

@ -7,11 +7,8 @@
from typing import AsyncGenerator, List, Optional, Union
from llama_models.datatypes import CoreModelId
from llama_models.llama3.api.chat_format import ChatFormat
from llama_models.llama3.api.tokenizer import Tokenizer
from together import Together
from llama_stack.apis.common.content_types import InterleavedContent
@ -53,7 +50,6 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
from .config import TogetherImplConfig
MODEL_ALIASES = [
build_model_alias(
"meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo",
@ -79,6 +75,10 @@ MODEL_ALIASES = [
"meta-llama/Llama-3.2-90B-Vision-Instruct-Turbo",
CoreModelId.llama3_2_90b_vision_instruct.value,
),
build_model_alias(
"meta-llama/Llama-3.3-70B-Instruct-Turbo",
CoreModelId.llama3_3_70b_instruct.value,
),
build_model_alias(
"meta-llama/Meta-Llama-Guard-3-8B",
CoreModelId.llama_guard_3_8b.value,
@ -135,7 +135,7 @@ class TogetherInferenceAdapter(
provider_data = self.get_request_provider_data()
if provider_data is None or not provider_data.together_api_key:
raise ValueError(
'Pass Together API Key in the header X-LlamaStack-ProviderData as { "together_api_key": <your api key>}'
'Pass Together API Key in the header X-LlamaStack-Provider-Data as { "together_api_key": <your api key>}'
)
together_api_key = provider_data.together_api_key
return Together(api_key=together_api_key)
@ -184,7 +184,7 @@ class TogetherInferenceAdapter(
sampling_params: Optional[SamplingParams] = SamplingParams(),
tools: Optional[List[ToolDefinition]] = None,
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
tool_prompt_format: Optional[ToolPromptFormat] = None,
response_format: Optional[ResponseFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,

View file

@ -10,7 +10,6 @@ from typing import AsyncGenerator, List, Optional, Union
from llama_models.llama3.api.chat_format import ChatFormat
from llama_models.llama3.api.tokenizer import Tokenizer
from llama_models.sku_list import all_registered_models
from openai import OpenAI
from llama_stack.apis.common.content_types import InterleavedContent
@ -33,7 +32,6 @@ from llama_stack.apis.inference import (
)
from llama_stack.apis.models import Model, ModelType
from llama_stack.providers.datatypes import ModelsProtocolPrivate
from llama_stack.providers.utils.inference.model_registry import (
build_model_alias,
ModelRegistryHelper,
@ -54,7 +52,6 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
from .config import VLLMInferenceAdapterConfig
log = logging.getLogger(__name__)
@ -105,7 +102,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
response_format: Optional[ResponseFormat] = None,
tools: Optional[List[ToolDefinition]] = None,
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
tool_prompt_format: Optional[ToolPromptFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> AsyncGenerator:

View file

@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@ -0,0 +1,21 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .bing_search import BingSearchToolRuntimeImpl
from .config import BingSearchToolConfig
__all__ = ["BingSearchToolConfig", "BingSearchToolRuntimeImpl"]
from pydantic import BaseModel
class BingSearchToolProviderDataValidator(BaseModel):
bing_search_api_key: str
async def get_adapter_impl(config: BingSearchToolConfig, _deps):
impl = BingSearchToolRuntimeImpl(config)
await impl.initialize()
return impl

View file

@ -0,0 +1,114 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import json
from typing import Any, Dict, List, Optional
import requests
from llama_stack.apis.common.content_types import URL
from llama_stack.apis.tools import (
Tool,
ToolDef,
ToolInvocationResult,
ToolParameter,
ToolRuntime,
)
from llama_stack.distribution.request_headers import NeedsRequestProviderData
from llama_stack.providers.datatypes import ToolsProtocolPrivate
from .config import BingSearchToolConfig
class BingSearchToolRuntimeImpl(
ToolsProtocolPrivate, ToolRuntime, NeedsRequestProviderData
):
def __init__(self, config: BingSearchToolConfig):
self.config = config
self.url = "https://api.bing.microsoft.com/v7.0/search"
async def initialize(self):
pass
async def register_tool(self, tool: Tool):
pass
async def unregister_tool(self, tool_id: str) -> None:
return
def _get_api_key(self) -> str:
if self.config.api_key:
return self.config.api_key
provider_data = self.get_request_provider_data()
if provider_data is None or not provider_data.bing_search_api_key:
raise ValueError(
'Pass Bing Search API Key in the header X-LlamaStack-Provider-Data as { "bing_search_api_key": <your api key>}'
)
return provider_data.bing_search_api_key
async def list_runtime_tools(
self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
) -> List[ToolDef]:
return [
ToolDef(
name="web_search",
description="Search the web using Bing Search API",
parameters=[
ToolParameter(
name="query",
description="The query to search for",
parameter_type="string",
)
],
)
]
async def invoke_tool(
self, tool_name: str, args: Dict[str, Any]
) -> ToolInvocationResult:
api_key = self._get_api_key()
headers = {
"Ocp-Apim-Subscription-Key": api_key,
}
params = {
"count": self.config.top_k,
"textDecorations": True,
"textFormat": "HTML",
"q": args["query"],
}
response = requests.get(
url=self.url,
params=params,
headers=headers,
)
response.raise_for_status()
return ToolInvocationResult(
content=json.dumps(self._clean_response(response.json()))
)
def _clean_response(self, search_response):
clean_response = []
query = search_response["queryContext"]["originalQuery"]
if "webPages" in search_response:
pages = search_response["webPages"]["value"]
for p in pages:
selected_keys = {"name", "url", "snippet"}
clean_response.append(
{k: v for k, v in p.items() if k in selected_keys}
)
if "news" in search_response:
clean_news = []
news = search_response["news"]["value"]
for n in news:
selected_keys = {"name", "url", "description"}
clean_news.append({k: v for k, v in n.items() if k in selected_keys})
clean_response.append(clean_news)
return {"query": query, "top_k": clean_response}

View file

@ -0,0 +1,16 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Optional
from pydantic import BaseModel
class BingSearchToolConfig(BaseModel):
"""Configuration for Bing Search Tool Runtime"""
api_key: Optional[str] = None
top_k: int = 3

View file

@ -11,10 +11,10 @@ from .config import BraveSearchToolConfig
class BraveSearchToolProviderDataValidator(BaseModel):
api_key: str
brave_search_api_key: str
async def get_provider_impl(config: BraveSearchToolConfig, _deps):
async def get_adapter_impl(config: BraveSearchToolConfig, _deps):
impl = BraveSearchToolRuntimeImpl(config)
await impl.initialize()
return impl

View file

@ -4,11 +4,19 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict, List
from typing import Any, Dict, List, Optional
import requests
from llama_models.llama3.api.datatypes import BuiltinTool
from llama_stack.apis.tools import Tool, ToolGroupDef, ToolInvocationResult, ToolRuntime
from llama_stack.apis.common.content_types import URL
from llama_stack.apis.tools import (
Tool,
ToolDef,
ToolInvocationResult,
ToolParameter,
ToolRuntime,
)
from llama_stack.distribution.request_headers import NeedsRequestProviderData
from llama_stack.providers.datatypes import ToolsProtocolPrivate
@ -25,8 +33,7 @@ class BraveSearchToolRuntimeImpl(
pass
async def register_tool(self, tool: Tool):
if tool.identifier != "brave_search":
raise ValueError(f"Tool identifier {tool.identifier} is not supported")
pass
async def unregister_tool(self, tool_id: str) -> None:
return
@ -36,14 +43,29 @@ class BraveSearchToolRuntimeImpl(
return self.config.api_key
provider_data = self.get_request_provider_data()
if provider_data is None or not provider_data.api_key:
if provider_data is None or not provider_data.brave_search_api_key:
raise ValueError(
'Pass Search provider\'s API Key in the header X-LlamaStack-ProviderData as { "api_key": <your api key>}'
'Pass Search provider\'s API Key in the header X-LlamaStack-Provider-Data as { "brave_search_api_key": <your api key>}'
)
return provider_data.api_key
return provider_data.brave_search_api_key
async def discover_tools(self, tool_group: ToolGroupDef) -> List[Tool]:
raise NotImplementedError("Brave search tool group not supported")
async def list_runtime_tools(
self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
) -> List[ToolDef]:
return [
ToolDef(
name="web_search",
description="Search the web for information",
parameters=[
ToolParameter(
name="query",
description="The query to search for",
parameter_type="string",
)
],
built_in_type=BuiltinTool.brave_search,
)
]
async def invoke_tool(
self, tool_name: str, args: Dict[str, Any]

View file

@ -4,7 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Optional
from typing import Any, Dict, Optional
from pydantic import BaseModel, Field
@ -18,3 +18,10 @@ class BraveSearchToolConfig(BaseModel):
default=3,
description="The maximum number of results to return",
)
@classmethod
def sample_run_config(cls, __distro_dir__: str) -> Dict[str, Any]:
return {
"api_key": "${env.BRAVE_SEARCH_API_KEY:}",
"max_results": 3,
}

View file

@ -4,22 +4,21 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict, List
from typing import Any, Dict, List, Optional
from urllib.parse import urlparse
from mcp import ClientSession
from mcp.client.sse import sse_client
from llama_stack.apis.common.content_types import URL
from llama_stack.apis.tools import (
MCPToolGroupDef,
ToolDef,
ToolGroupDef,
ToolInvocationResult,
ToolParameter,
ToolRuntime,
)
from llama_stack.providers.datatypes import ToolsProtocolPrivate
from mcp import ClientSession
from mcp.client.sse import sse_client
from .config import ModelContextProtocolConfig
@ -30,12 +29,14 @@ class ModelContextProtocolToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
async def initialize(self):
pass
async def discover_tools(self, tool_group: ToolGroupDef) -> List[ToolDef]:
if not isinstance(tool_group, MCPToolGroupDef):
raise ValueError(f"Unsupported tool group type: {type(tool_group)}")
async def list_runtime_tools(
self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
) -> List[ToolDef]:
if mcp_endpoint is None:
raise ValueError("mcp_endpoint is required")
tools = []
async with sse_client(tool_group.endpoint.uri) as streams:
async with sse_client(mcp_endpoint.uri) as streams:
async with ClientSession(*streams) as session:
await session.initialize()
tools_result = await session.list_tools()
@ -57,7 +58,7 @@ class ModelContextProtocolToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
description=tool.description,
parameters=parameters,
metadata={
"endpoint": tool_group.endpoint.uri,
"endpoint": mcp_endpoint.uri,
},
)
)

Some files were not shown because too many files have changed in this diff Show more