mirror of https://github.com/meta-llama/llama-stack.git
synced 2025-08-06 02:32:40 +00:00

Merge remote-tracking branch 'origin/main' into support_more_data_format

commit a3b1c3438b

171 changed files with 14529 additions and 5612 deletions
.github/CODEOWNERS (vendored): 2 lines changed

@@ -2,4 +2,4 @@
 
 # These owners will be the default owners for everything in
 # the repo. Unless a later match takes precedence,
-* @ashwinb @yanxi0830 @hardikjshah @dltn @raghotham @dineshyv @vladimirivic
+* @ashwinb @yanxi0830 @hardikjshah @dltn @raghotham @dineshyv @vladimirivic @sixianyi0721
.github/workflows/publish-to-test-pypi.yml (vendored, new file): 232 lines

@@ -0,0 +1,232 @@
name: Publish Python 🐍 distribution 📦 to TestPyPI

on:
  workflow_dispatch: # Keep manual trigger
    inputs:
      version:
        description: 'Version number (e.g. 0.0.63.dev20250111)'
        required: true
        type: string
  schedule:
    - cron: "0 0 * * *" # Run every day at midnight

jobs:
  trigger-client-and-models-build:
    name: Trigger llama-stack-client and llama-models build
    runs-on: ubuntu-latest
    outputs:
      version: ${{ steps.version.outputs.version }}
      client_run_id: ${{ steps.trigger-client.outputs.workflow_id }}
      model_run_id: ${{ steps.trigger-models.outputs.workflow_id }}
    steps:
      - uses: actions/checkout@v4
        with:
          persist-credentials: false
      - name: Get date
        id: date
        run: echo "date=$(date +'%Y%m%d')" >> $GITHUB_OUTPUT
      - name: Compute version based on dispatch event
        id: version
        run: |
          # Read base version from setup.py
          version=$(sed -n 's/.*version="\([^"]*\)".*/\1/p' setup.py)
          if [ "${{ github.event_name }}" = "schedule" ]; then
            echo "version=${version}.dev${{ steps.date.outputs.date }}" >> $GITHUB_OUTPUT
          elif [ "${{ github.event_name }}" = "workflow_dispatch" ]; then
            echo "version=${{ inputs.version }}" >> $GITHUB_OUTPUT
          else
            echo "version=${version}.dev$(shuf -i 10000000-99999999 -n 1)" >> $GITHUB_OUTPUT
          fi
      - name: Trigger llama-stack-client workflow
        id: trigger-client
        run: |
          response=$(curl -X POST https://api.github.com/repos/meta-llama/llama-stack-client-python/dispatches \
            -H 'Accept: application/vnd.github.everest-preview+json' \
            -H "authorization: Bearer ${{ secrets.PAT_TOKEN }}" \
            --data "{\"event_type\": \"build-client-package\", \"client_payload\": {\"source\": \"llama-stack-nightly\", \"version\": \"${{ steps.version.outputs.version }}\"}}" \
            -w "\n%{http_code}")

          http_code=$(echo "$response" | tail -n1)
          if [ "$http_code" != "204" ]; then
            echo "Failed to trigger client workflow"
            exit 1
          fi

          # Get the run ID of the triggered workflow
          sleep 5 # Wait for workflow to be created
          run_id=$(curl -s -H "authorization: Bearer ${{ secrets.PAT_TOKEN }}" \
            "https://api.github.com/repos/meta-llama/llama-stack-client-python/actions/runs?event=repository_dispatch" \
            | jq '.workflow_runs[0].id')
          echo "workflow_id=$run_id" >> $GITHUB_OUTPUT

      - name: Trigger llama-models workflow
        id: trigger-models
        run: |
          response=$(curl -X POST https://api.github.com/repos/meta-llama/llama-models/dispatches \
            -H 'Accept: application/vnd.github.everest-preview+json' \
            -H "authorization: Bearer ${{ secrets.PAT_TOKEN }}" \
            --data "{\"event_type\": \"build-models-package\", \"client_payload\": {\"source\": \"llama-stack-nightly\", \"version\": \"${{ steps.version.outputs.version }}\"}}" \
            -w "\n%{http_code}")

          http_code=$(echo "$response" | tail -n1)
          if [ "$http_code" != "204" ]; then
            echo "Failed to trigger models workflow"
            exit 1
          fi

          # Get the run ID of the triggered workflow
          sleep 5 # Wait for workflow to be created
          run_id=$(curl -s -H "authorization: Bearer ${{ secrets.PAT_TOKEN }}" \
            "https://api.github.com/repos/meta-llama/llama-models/actions/runs?event=repository_dispatch" \
            | jq '.workflow_runs[0].id')
          echo "workflow_id=$run_id" >> $GITHUB_OUTPUT

  wait-for-workflows:
    name: Wait for triggered workflows
    needs: trigger-client-and-models-build
    runs-on: ubuntu-latest
    steps:
      - name: Wait for client workflow
        run: |
          while true; do
            status=$(curl -s -H "authorization: Bearer ${{ secrets.PAT_TOKEN }}" \
              "https://api.github.com/repos/meta-llama/llama-stack-client-python/actions/runs/${{ needs.trigger-client-and-models-build.outputs.client_run_id }}" \
              | jq -r '.status')
            conclusion=$(curl -s -H "authorization: Bearer ${{ secrets.PAT_TOKEN }}" \
              "https://api.github.com/repos/meta-llama/llama-stack-client-python/actions/runs/${{ needs.trigger-client-and-models-build.outputs.client_run_id }}" \
              | jq -r '.conclusion')

            echo "llama-stack-client-python workflow status: $status, conclusion: $conclusion"

            if [ "$status" = "completed" ]; then
              if [ "$conclusion" != "success" ]; then
                echo "llama-stack-client-python workflow failed"
                exit 1
              fi
              break
            fi

            sleep 10
          done

      - name: Wait for models workflow
        run: |
          while true; do
            status=$(curl -s -H "authorization: Bearer ${{ secrets.PAT_TOKEN }}" \
              "https://api.github.com/repos/meta-llama/llama-models/actions/runs/${{ needs.trigger-client-and-models-build.outputs.model_run_id }}" \
              | jq -r '.status')
            conclusion=$(curl -s -H "authorization: Bearer ${{ secrets.PAT_TOKEN }}" \
              "https://api.github.com/repos/meta-llama/llama-models/actions/runs/${{ needs.trigger-client-and-models-build.outputs.model_run_id }}" \
              | jq -r '.conclusion')

            echo "llama-models workflow status: $status, conclusion: $conclusion"

            if [ "$status" = "completed" ]; then
              if [ "$conclusion" != "success" ]; then
                echo "llama-models workflow failed"
                exit 1
              fi
              break
            fi

            sleep 10
          done

  build:
    name: Build distribution 📦
    needs:
      - wait-for-workflows
      - trigger-client-and-models-build
    runs-on: ubuntu-latest

    steps:
      - uses: actions/checkout@v4
        with:
          persist-credentials: false
      - name: Get date
        id: date
        run: echo "date=$(date +'%Y%m%d')" >> $GITHUB_OUTPUT
      - name: Update version for nightly
        run: |
          sed -i 's/version="\([^"]*\)"/version="${{ needs.trigger-client-and-models-build.outputs.version }}"/' setup.py
          sed -i 's/llama-stack-client>=\([^"]*\)/llama-stack-client==${{ needs.trigger-client-and-models-build.outputs.version }}/' requirements.txt
          sed -i 's/llama-models>=\([^"]*\)/llama-models==${{ needs.trigger-client-and-models-build.outputs.version }}/' requirements.txt
      - name: Set up Python
        uses: actions/setup-python@v5
        with:
          python-version: "3.11"
      - name: Install pypa/build
        run: >-
          python3 -m
          pip install
          build
          --user
      - name: Build a binary wheel and a source tarball
        run: python3 -m build
      - name: Store the distribution packages
        uses: actions/upload-artifact@v4
        with:
          name: python-package-distributions
          path: dist/

  publish-to-testpypi:
    name: Publish Python 🐍 distribution 📦 to TestPyPI
    needs:
      - build
    runs-on: ubuntu-latest

    environment:
      name: testrelease
      url: https://test.pypi.org/p/llama-stack

    permissions:
      id-token: write  # IMPORTANT: mandatory for trusted publishing

    steps:
      - name: Download all the dists
        uses: actions/download-artifact@v4
        with:
          name: python-package-distributions
          path: dist/
      - name: Publish distribution 📦 to TestPyPI
        uses: pypa/gh-action-pypi-publish@release/v1
        with:
          repository-url: https://test.pypi.org/legacy/

  test-published-package:
    name: Test published package
    needs:
      - publish-to-testpypi
      - trigger-client-and-models-build
    runs-on: ubuntu-latest
    steps:
      - name: Install the package
        run: |
          max_attempts=6
          attempt=1
          while [ $attempt -le $max_attempts ]; do
            echo "Attempt $attempt of $max_attempts to install package..."
            if pip install --no-cache --index-url https://pypi.org/simple/ --extra-index-url https://test.pypi.org/simple/ llama-stack==${{ needs.trigger-client-and-models-build.outputs.version }}; then
              echo "Package installed successfully"
              break
            fi
            if [ $attempt -ge $max_attempts ]; then
              echo "Failed to install package after $max_attempts attempts"
              exit 1
            fi
            attempt=$((attempt + 1))
            sleep 10
          done
      - name: Test the package versions
        run: |
          pip list | grep llama_
      - name: Test CLI commands
        run: |
          llama model list
          llama stack build --list-templates
          llama model prompt-format -m Llama3.2-11B-Vision-Instruct
          llama stack list-apis
          llama stack list-providers inference
          llama stack list-providers telemetry

# TODO: add trigger for integration test workflow & docker builds
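The scheduled run above builds nightly versions of the form `<base>.dev<YYYYMMDD>`. This follows the PEP 440 dev-release scheme, so nightlies sort before the corresponding final release. Below is a minimal Python sketch of that ordering; it assumes the `packaging` library, which is not part of the workflow itself:

```python
from packaging.version import Version  # assumption: packaging is installed

base = Version("0.0.63")
nightly = Version("0.0.63.dev20250111")  # "<base>.dev<YYYYMMDD>", as computed in the workflow

assert nightly.is_devrelease
assert nightly < base  # dev releases sort before the corresponding final release
print(nightly.release, nightly.dev)  # (0, 0, 63) 20250111
```

This is why pinning `llama-stack-client==<nightly version>` in the build job picks up exactly the client package that was just published, rather than an older final release.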

@@ -99,7 +99,7 @@ Additionally, we have designed every element of the Stack such that APIs as well
 |:---------------------------------------------:|:-------------------------------------------------------------------------------------------------------------------------------------------------------------:|:------------------------------------------------------------------------------------------------------------------------:|
 | Meta Reference | [llamastack/distribution-meta-reference-gpu](https://hub.docker.com/repository/docker/llamastack/distribution-meta-reference-gpu/general) | [Guide](https://llama-stack.readthedocs.io/en/latest/distributions/self_hosted_distro/meta-reference-gpu.html) |
 | Meta Reference Quantized | [llamastack/distribution-meta-reference-quantized-gpu](https://hub.docker.com/repository/docker/llamastack/distribution-meta-reference-quantized-gpu/general) | [Guide](https://llama-stack.readthedocs.io/en/latest/distributions/self_hosted_distro/meta-reference-quantized-gpu.html) |
-| Cerebras | [llamastack/distribution-cerebras](https://hub.docker.com/repository/docker/llamastack/distribution-cerebras/general) | [Guide](https://llama-stack.readthedocs.io/en/latest/getting_started/distributions/self_hosted_distro/cerebras.html) |
+| Cerebras | [llamastack/distribution-cerebras](https://hub.docker.com/repository/docker/llamastack/distribution-cerebras/general) | [Guide](https://llama-stack.readthedocs.io/en/latest/distributions/self_hosted_distro/cerebras.html) |
 | Ollama | [llamastack/distribution-ollama](https://hub.docker.com/repository/docker/llamastack/distribution-ollama/general) | [Guide](https://llama-stack.readthedocs.io/en/latest/distributions/self_hosted_distro/ollama.html) |
 | TGI | [llamastack/distribution-tgi](https://hub.docker.com/repository/docker/llamastack/distribution-tgi/general) | [Guide](https://llama-stack.readthedocs.io/en/latest/distributions/self_hosted_distro/tgi.html) |
 | Together | [llamastack/distribution-together](https://hub.docker.com/repository/docker/llamastack/distribution-together/general) | [Guide](https://llama-stack.readthedocs.io/en/latest/distributions/self_hosted_distro/together.html) |

@@ -1,9 +1,42 @@
 {
-  "bedrock": [
+  "hf-serverless": [
+    "aiohttp",
+    "aiosqlite",
+    "autoevals",
+    "blobfile",
+    "chardet",
+    "chromadb-client",
+    "datasets",
+    "faiss-cpu",
+    "fastapi",
+    "fire",
+    "httpx",
+    "huggingface_hub",
+    "matplotlib",
+    "nltk",
+    "numpy",
+    "openai",
+    "opentelemetry-exporter-otlp-proto-http",
+    "opentelemetry-sdk",
+    "pandas",
+    "pillow",
+    "psycopg2-binary",
+    "pypdf",
+    "redis",
+    "requests",
+    "scikit-learn",
+    "scipy",
+    "sentencepiece",
+    "tqdm",
+    "transformers",
+    "uvicorn",
+    "sentence-transformers --no-deps",
+    "torch --index-url https://download.pytorch.org/whl/cpu"
+  ],
+  "together": [
     "aiosqlite",
     "autoevals",
     "blobfile",
-    "boto3",
     "chardet",
     "chromadb-client",
     "datasets",
@@ -22,6 +55,71 @@
     "psycopg2-binary",
     "pypdf",
     "redis",
+    "requests",
+    "scikit-learn",
+    "scipy",
+    "sentencepiece",
+    "together",
+    "tqdm",
+    "transformers",
+    "uvicorn",
+    "sentence-transformers --no-deps",
+    "torch --index-url https://download.pytorch.org/whl/cpu"
+  ],
+  "vllm-gpu": [
+    "aiosqlite",
+    "autoevals",
+    "blobfile",
+    "chardet",
+    "chromadb-client",
+    "datasets",
+    "faiss-cpu",
+    "fastapi",
+    "fire",
+    "httpx",
+    "matplotlib",
+    "nltk",
+    "numpy",
+    "openai",
+    "opentelemetry-exporter-otlp-proto-http",
+    "opentelemetry-sdk",
+    "pandas",
+    "pillow",
+    "psycopg2-binary",
+    "pypdf",
+    "redis",
+    "requests",
+    "scikit-learn",
+    "scipy",
+    "sentencepiece",
+    "tqdm",
+    "transformers",
+    "uvicorn",
+    "vllm",
+    "sentence-transformers --no-deps",
+    "torch --index-url https://download.pytorch.org/whl/cpu"
+  ],
+  "remote-vllm": [
+    "aiosqlite",
+    "blobfile",
+    "chardet",
+    "chromadb-client",
+    "faiss-cpu",
+    "fastapi",
+    "fire",
+    "httpx",
+    "matplotlib",
+    "nltk",
+    "numpy",
+    "openai",
+    "opentelemetry-exporter-otlp-proto-http",
+    "opentelemetry-sdk",
+    "pandas",
+    "pillow",
+    "psycopg2-binary",
+    "pypdf",
+    "redis",
+    "requests",
     "scikit-learn",
     "scipy",
     "sentencepiece",
@@ -54,6 +152,7 @@
     "psycopg2-binary",
     "pypdf",
     "redis",
+    "requests",
     "scikit-learn",
     "scipy",
     "sentencepiece",
@@ -63,7 +162,7 @@
     "sentence-transformers --no-deps",
     "torch --index-url https://download.pytorch.org/whl/cpu"
   ],
-  "hf-endpoint": [
+  "tgi": [
     "aiohttp",
     "aiosqlite",
     "autoevals",
@@ -87,6 +186,7 @@
     "psycopg2-binary",
     "pypdf",
     "redis",
+    "requests",
     "scikit-learn",
     "scipy",
     "sentencepiece",
@@ -96,11 +196,11 @@
     "sentence-transformers --no-deps",
     "torch --index-url https://download.pytorch.org/whl/cpu"
   ],
-  "hf-serverless": [
-    "aiohttp",
+  "bedrock": [
     "aiosqlite",
     "autoevals",
     "blobfile",
+    "boto3",
     "chardet",
     "chromadb-client",
     "datasets",
@@ -108,7 +208,6 @@
     "fastapi",
     "fire",
     "httpx",
-    "huggingface_hub",
     "matplotlib",
     "nltk",
     "numpy",
@@ -120,6 +219,7 @@
     "psycopg2-binary",
     "pypdf",
     "redis",
+    "requests",
     "scikit-learn",
     "scipy",
     "sentencepiece",
@@ -154,6 +254,7 @@
     "psycopg2-binary",
     "pypdf",
     "redis",
+    "requests",
     "scikit-learn",
     "scipy",
     "sentence-transformers",
@@ -193,6 +294,7 @@
     "psycopg2-binary",
     "pypdf",
     "redis",
+    "requests",
     "scikit-learn",
     "scipy",
     "sentence-transformers",
@@ -207,6 +309,35 @@
     "sentence-transformers --no-deps",
     "torch --index-url https://download.pytorch.org/whl/cpu"
   ],
+  "cerebras": [
+    "aiosqlite",
+    "blobfile",
+    "cerebras_cloud_sdk",
+    "chardet",
+    "faiss-cpu",
+    "fastapi",
+    "fire",
+    "httpx",
+    "matplotlib",
+    "nltk",
+    "numpy",
+    "opentelemetry-exporter-otlp-proto-http",
+    "opentelemetry-sdk",
+    "pandas",
+    "pillow",
+    "psycopg2-binary",
+    "pypdf",
+    "redis",
+    "requests",
+    "scikit-learn",
+    "scipy",
+    "sentencepiece",
+    "tqdm",
+    "transformers",
+    "uvicorn",
+    "sentence-transformers --no-deps",
+    "torch --index-url https://download.pytorch.org/whl/cpu"
+  ],
   "ollama": [
     "aiohttp",
     "aiosqlite",
@@ -231,6 +362,7 @@
     "psycopg2-binary",
     "pypdf",
     "redis",
+    "requests",
     "scikit-learn",
     "scipy",
     "sentencepiece",
@@ -240,7 +372,7 @@
     "sentence-transformers --no-deps",
     "torch --index-url https://download.pytorch.org/whl/cpu"
   ],
-  "tgi": [
+  "hf-endpoint": [
     "aiohttp",
     "aiosqlite",
     "autoevals",
@@ -264,127 +396,7 @@
     "psycopg2-binary",
     "pypdf",
     "redis",
-    "scikit-learn",
-    "scipy",
-    "sentencepiece",
-    "tqdm",
-    "transformers",
-    "uvicorn",
-    "sentence-transformers --no-deps",
-    "torch --index-url https://download.pytorch.org/whl/cpu"
-  ],
-  "together": [
-    "aiosqlite",
-    "autoevals",
-    "blobfile",
-    "chardet",
-    "chromadb-client",
-    "datasets",
-    "faiss-cpu",
-    "fastapi",
-    "fire",
-    "httpx",
-    "matplotlib",
-    "nltk",
-    "numpy",
-    "openai",
-    "opentelemetry-exporter-otlp-proto-http",
-    "opentelemetry-sdk",
-    "pandas",
-    "pillow",
-    "psycopg2-binary",
-    "pypdf",
-    "redis",
-    "scikit-learn",
-    "scipy",
-    "sentencepiece",
-    "together",
-    "tqdm",
-    "transformers",
-    "uvicorn",
-    "sentence-transformers --no-deps",
-    "torch --index-url https://download.pytorch.org/whl/cpu"
-  ],
-  "remote-vllm": [
-    "aiosqlite",
-    "blobfile",
-    "chardet",
-    "chromadb-client",
-    "faiss-cpu",
-    "fastapi",
-    "fire",
-    "httpx",
-    "matplotlib",
-    "nltk",
-    "numpy",
-    "openai",
-    "opentelemetry-exporter-otlp-proto-http",
-    "opentelemetry-sdk",
-    "pandas",
-    "pillow",
-    "psycopg2-binary",
-    "pypdf",
-    "redis",
-    "scikit-learn",
-    "scipy",
-    "sentencepiece",
-    "tqdm",
-    "transformers",
-    "uvicorn",
-    "sentence-transformers --no-deps",
-    "torch --index-url https://download.pytorch.org/whl/cpu"
-  ],
-  "vllm-gpu": [
-    "aiosqlite",
-    "autoevals",
-    "blobfile",
-    "chardet",
-    "chromadb-client",
-    "datasets",
-    "faiss-cpu",
-    "fastapi",
-    "fire",
-    "httpx",
-    "matplotlib",
-    "nltk",
-    "numpy",
-    "openai",
-    "opentelemetry-exporter-otlp-proto-http",
-    "opentelemetry-sdk",
-    "pandas",
-    "pillow",
-    "psycopg2-binary",
-    "pypdf",
-    "redis",
-    "scikit-learn",
-    "scipy",
-    "sentencepiece",
-    "tqdm",
-    "transformers",
-    "uvicorn",
-    "vllm",
-    "sentence-transformers --no-deps",
-    "torch --index-url https://download.pytorch.org/whl/cpu"
-  ],
-  "cerebras": [
-    "aiosqlite",
-    "blobfile",
-    "cerebras_cloud_sdk",
-    "chardet",
-    "faiss-cpu",
-    "fastapi",
-    "fire",
-    "httpx",
-    "matplotlib",
-    "nltk",
-    "numpy",
-    "opentelemetry-exporter-otlp-proto-http",
-    "opentelemetry-sdk",
-    "pandas",
-    "pillow",
-    "psycopg2-binary",
-    "pypdf",
-    "redis",
+    "requests",
     "scikit-learn",
     "scipy",
     "sentencepiece",
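For context, each entry in the per-provider lists above is a plain pip requirement string, and some entries carry extra pip flags inline (for example `torch --index-url https://download.pytorch.org/whl/cpu`). A minimal Python sketch of consuming such a map follows; the `distributions/dependencies.json` path and the `"ollama"` key are assumptions chosen for illustration, since the capture does not show the file name:

```python
import json
import subprocess
import sys

# Assumed path; the diff above does not show the file name.
with open("distributions/dependencies.json") as f:
    deps = json.load(f)

for spec in deps["ollama"]:
    # Entries like "torch --index-url https://download.pytorch.org/whl/cpu" embed pip
    # flags in the string, so split before handing the pieces to pip.
    subprocess.run([sys.executable, "-m", "pip", "install", *spec.split()], check=True)
```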
@@ -85,7 +85,7 @@ services:
       - SQLITE_STORE_DIR=${SQLITE_STORE_DIR:-$HOME/.llama/distributions/remote-vllm}
       - SAFETY_MODEL=${SAFETY_MODEL:-meta-llama/Llama-Guard-3-1B}
     ports:
-      - "${LLAMASTACK_PORT:-5001}:${LLAMASTACK_PORT:-5001}"
+      - "${LLAMA_STACK_PORT:-5001}:${LLAMA_STACK_PORT:-5001}"
     # Hack: wait for vLLM server to start before starting docker
     entrypoint: bash -c "sleep 60; python -m llama_stack.distribution.server.server --yaml_config /root/llamastack-run-remote-vllm.yaml --port 5001"
     deploy:
File diff suppressed because one or more lines are too long

@@ -486,13 +486,22 @@ class Generator:
         parameters = path_parameters + query_parameters
         parameters += [
             Parameter(
-                name="X-LlamaStack-ProviderData",
+                name="X-LlamaStack-Provider-Data",
                 in_=ParameterLocation.Header,
                 description="JSON-encoded provider data which will be made available to the adapter servicing the API",
                 required=False,
                 schema=self.schema_builder.classdef_to_ref(str),
             )
         ]
+        parameters += [
+            Parameter(
+                name="X-LlamaStack-Client-Version",
+                in_=ParameterLocation.Header,
+                description="Version of the client making the request. This is used to ensure that the client and server are compatible.",
+                required=False,
+                schema=self.schema_builder.classdef_to_ref(str),
+            )
+        ]
 
         # data passed in payload
         if op.request_params:
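Both headers in the generator change above are optional request headers: `X-LlamaStack-Provider-Data` carries JSON-encoded provider data and `X-LlamaStack-Client-Version` carries the client version. A minimal sketch of a client sending them with `httpx` (which appears in the dependency lists earlier); the server URL, route, and the provider-data key are hypothetical, not taken from the diff:

```python
import json
import httpx

headers = {
    # JSON-encoded provider data, passed through to the adapter servicing the API.
    "X-LlamaStack-Provider-Data": json.dumps({"example_api_key": "..."}),  # hypothetical payload
    # Client version, used by the server to check client/server compatibility.
    "X-LlamaStack-Client-Version": "0.0.63",
}

# Hypothetical server URL and route, for illustration only.
response = httpx.get("http://localhost:5001/v1/models", headers=headers)
print(response.status_code)
```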
File diff suppressed because it is too large
File diff suppressed because it is too large

@@ -8,10 +8,6 @@ building_distro
 configuration
 ```
 
-<!-- self_hosted_distro/index -->
-<!-- remote_hosted_distro/index -->
-<!-- ondevice_distro/index -->
-
 You can instantiate a Llama Stack in one of the following ways:
 - **As a Library**: this is the simplest, especially if you are using an external inference service. See [Using Llama Stack as a Library](importing_as_library)
 - **Docker**: we provide a number of pre-built Docker containers so you can start a Llama Stack server instantly. You can also build your own custom Docker container.
@@ -30,11 +26,15 @@ If so, we suggest:
   - {dockerhub}`distribution-ollama` ([Guide](self_hosted_distro/ollama))
 
 - **Do you have an API key for a remote inference provider like Fireworks, Together, etc.?** If so, we suggest:
-  - {dockerhub}`distribution-together` ([Guide](remote_hosted_distro/index))
-  - {dockerhub}`distribution-fireworks` ([Guide](remote_hosted_distro/index))
+  - {dockerhub}`distribution-together` ([Guide](self_hosted_distro/together))
+  - {dockerhub}`distribution-fireworks` ([Guide](self_hosted_distro/fireworks))
 
 - **Do you want to run Llama Stack inference on your iOS / Android device** If so, we suggest:
   - [iOS SDK](ondevice_distro/ios_sdk)
   - [Android](ondevice_distro/android_sdk)
 
+- **Do you want a hosted Llama Stack endpoint?** If so, we suggest:
+  - [Remote-Hosted Llama Stack Endpoints](remote_hosted_distro/index)
+
+
 You can also build your own [custom distribution](building_distro).
@@ -19,6 +19,7 @@ The `llamastack/distribution-bedrock` distribution consists of the following pro
 | safety | `remote::bedrock` |
 | scoring | `inline::basic`, `inline::llm-as-judge`, `inline::braintrust` |
 | telemetry | `inline::meta-reference` |
+| tool_runtime | `remote::brave-search`, `remote::tavily-search`, `inline::code-interpreter`, `inline::memory-runtime` |
 
 
 
@@ -26,7 +27,7 @@ The `llamastack/distribution-bedrock` distribution consists of the following pro
 
 The following environment variables can be configured:
 
-- `LLAMASTACK_PORT`: Port for the Llama Stack distribution server (default: `5001`)
+- `LLAMA_STACK_PORT`: Port for the Llama Stack distribution server (default: `5001`)
 
 ### Models
 
@@ -1,5 +1,15 @@
+---
+orphan: true
+---
 # Cerebras Distribution
 
+```{toctree}
+:maxdepth: 2
+:hidden:
+
+self
+```
+
 The `llamastack/distribution-cerebras` distribution consists of the following provider configurations.
 
 | API | Provider(s) |
@@ -9,13 +19,14 @@ The `llamastack/distribution-cerebras` distribution consists of the following pr
 | memory | `inline::meta-reference` |
 | safety | `inline::llama-guard` |
 | telemetry | `inline::meta-reference` |
+| tool_runtime | `remote::brave-search`, `remote::tavily-search`, `inline::code-interpreter`, `inline::memory-runtime` |
 
 
 ### Environment Variables
 
 The following environment variables can be configured:
 
-- `LLAMASTACK_PORT`: Port for the Llama Stack distribution server (default: `5001`)
+- `LLAMA_STACK_PORT`: Port for the Llama Stack distribution server (default: `5001`)
 - `CEREBRAS_API_KEY`: Cerebras API Key (default: ``)
 
 ### Models
@@ -22,28 +22,30 @@ The `llamastack/distribution-fireworks` distribution consists of the following p
 | safety | `inline::llama-guard` |
 | scoring | `inline::basic`, `inline::llm-as-judge`, `inline::braintrust` |
 | telemetry | `inline::meta-reference` |
+| tool_runtime | `remote::brave-search`, `remote::tavily-search`, `inline::code-interpreter`, `inline::memory-runtime` |
 
 
 ### Environment Variables
 
 The following environment variables can be configured:
 
-- `LLAMASTACK_PORT`: Port for the Llama Stack distribution server (default: `5001`)
+- `LLAMA_STACK_PORT`: Port for the Llama Stack distribution server (default: `5001`)
 - `FIREWORKS_API_KEY`: Fireworks.AI API Key (default: ``)
 
 ### Models
 
 The following models are available by default:
 
-- `meta-llama/Llama-3.1-8B-Instruct (fireworks/llama-v3p1-8b-instruct)`
-- `meta-llama/Llama-3.1-70B-Instruct (fireworks/llama-v3p1-70b-instruct)`
-- `meta-llama/Llama-3.1-405B-Instruct-FP8 (fireworks/llama-v3p1-405b-instruct)`
-- `meta-llama/Llama-3.2-1B-Instruct (fireworks/llama-v3p2-1b-instruct)`
-- `meta-llama/Llama-3.2-3B-Instruct (fireworks/llama-v3p2-3b-instruct)`
-- `meta-llama/Llama-3.2-11B-Vision-Instruct (fireworks/llama-v3p2-11b-vision-instruct)`
-- `meta-llama/Llama-3.2-90B-Vision-Instruct (fireworks/llama-v3p2-90b-vision-instruct)`
-- `meta-llama/Llama-Guard-3-8B (fireworks/llama-guard-3-8b)`
-- `meta-llama/Llama-Guard-3-11B-Vision (fireworks/llama-guard-3-11b-vision)`
+- `meta-llama/Llama-3.1-8B-Instruct (accounts/fireworks/models/llama-v3p1-8b-instruct)`
+- `meta-llama/Llama-3.1-70B-Instruct (accounts/fireworks/models/llama-v3p1-70b-instruct)`
+- `meta-llama/Llama-3.1-405B-Instruct-FP8 (accounts/fireworks/models/llama-v3p1-405b-instruct)`
+- `meta-llama/Llama-3.2-1B-Instruct (accounts/fireworks/models/llama-v3p2-1b-instruct)`
+- `meta-llama/Llama-3.2-3B-Instruct (accounts/fireworks/models/llama-v3p2-3b-instruct)`
+- `meta-llama/Llama-3.2-11B-Vision-Instruct (accounts/fireworks/models/llama-v3p2-11b-vision-instruct)`
+- `meta-llama/Llama-3.2-90B-Vision-Instruct (accounts/fireworks/models/llama-v3p2-90b-vision-instruct)`
+- `meta-llama/Llama-3.3-70B-Instruct (accounts/fireworks/models/llama-v3p3-70b-instruct)`
+- `meta-llama/Llama-Guard-3-8B (accounts/fireworks/models/llama-guard-3-8b)`
+- `meta-llama/Llama-Guard-3-11B-Vision (accounts/fireworks/models/llama-guard-3-11b-vision)`
 
 
 ### Prerequisite: API Keys
 
@@ -22,6 +22,7 @@ The `llamastack/distribution-meta-reference-gpu` distribution consists of the fo
 | safety | `inline::llama-guard` |
 | scoring | `inline::basic`, `inline::llm-as-judge`, `inline::braintrust` |
 | telemetry | `inline::meta-reference` |
+| tool_runtime | `remote::brave-search`, `remote::tavily-search`, `inline::code-interpreter`, `inline::memory-runtime` |
 
 
 Note that you need access to nvidia GPUs to run this distribution. This distribution is not compatible with CPU-only machines or machines with AMD GPUs.
@@ -30,7 +31,7 @@ Note that you need access to nvidia GPUs to run this distribu
 
 The following environment variables can be configured:
 
-- `LLAMASTACK_PORT`: Port for the Llama Stack distribution server (default: `5001`)
+- `LLAMA_STACK_PORT`: Port for the Llama Stack distribution server (default: `5001`)
 - `INFERENCE_MODEL`: Inference model loaded into the Meta Reference server (default: `meta-llama/Llama-3.2-3B-Instruct`)
 - `INFERENCE_CHECKPOINT_DIR`: Directory containing the Meta Reference model checkpoint (default: `null`)
 - `SAFETY_MODEL`: Name of the safety (Llama-Guard) model to use (default: `meta-llama/Llama-Guard-3-1B`)
@@ -22,6 +22,7 @@ The `llamastack/distribution-meta-reference-quantized-gpu` distribution consists
 | safety | `inline::llama-guard` |
 | scoring | `inline::basic`, `inline::llm-as-judge`, `inline::braintrust` |
 | telemetry | `inline::meta-reference` |
+| tool_runtime | `remote::brave-search`, `remote::tavily-search`, `inline::code-interpreter`, `inline::memory-runtime` |
 
 
 The only difference vs. the `meta-reference-gpu` distribution is that it has support for more efficient inference -- with fp8, int4 quantization, etc.
@@ -32,7 +33,7 @@ Note that you need access to nvidia GPUs to run this distribu
 
 The following environment variables can be configured:
 
-- `LLAMASTACK_PORT`: Port for the Llama Stack distribution server (default: `5001`)
+- `LLAMA_STACK_PORT`: Port for the Llama Stack distribution server (default: `5001`)
 - `INFERENCE_MODEL`: Inference model loaded into the Meta Reference server (default: `meta-llama/Llama-3.2-3B-Instruct`)
 - `INFERENCE_CHECKPOINT_DIR`: Directory containing the Meta Reference model checkpoint (default: `null`)
 
@@ -22,13 +22,14 @@ The `llamastack/distribution-ollama` distribution consists of the following prov
 | safety | `inline::llama-guard` |
 | scoring | `inline::basic`, `inline::llm-as-judge`, `inline::braintrust` |
 | telemetry | `inline::meta-reference` |
+| tool_runtime | `remote::brave-search`, `remote::tavily-search`, `inline::code-interpreter`, `inline::memory-runtime` |
 
 
 You should use this distribution if you have a regular desktop machine without very powerful GPUs. Of course, if you have powerful GPUs, you can still continue using this distribution since Ollama supports GPU acceleration.### Environment Variables
 
 The following environment variables can be configured:
 
-- `LLAMASTACK_PORT`: Port for the Llama Stack distribution server (default: `5001`)
+- `LLAMA_STACK_PORT`: Port for the Llama Stack distribution server (default: `5001`)
 - `OLLAMA_URL`: URL of the Ollama server (default: `http://127.0.0.1:11434`)
 - `INFERENCE_MODEL`: Inference model loaded into the Ollama server (default: `meta-llama/Llama-3.2-3B-Instruct`)
 - `SAFETY_MODEL`: Safety model loaded into the Ollama server (default: `meta-llama/Llama-Guard-3-1B`)
@@ -18,6 +18,7 @@ The `llamastack/distribution-remote-vllm` distribution consists of the following
 | memory | `inline::faiss`, `remote::chromadb`, `remote::pgvector` |
 | safety | `inline::llama-guard` |
 | telemetry | `inline::meta-reference` |
+| tool_runtime | `remote::brave-search`, `remote::tavily-search`, `inline::code-interpreter`, `inline::memory-runtime` |
 
 
 You can use this distribution if you have GPUs and want to run an independent vLLM server container for running inference.
@@ -26,9 +27,9 @@ You can use this distribution if you have GPUs and want to run an independent vL
 
 The following environment variables can be configured:
 
-- `LLAMASTACK_PORT`: Port for the Llama Stack distribution server (default: `5001`)
+- `LLAMA_STACK_PORT`: Port for the Llama Stack distribution server (default: `5001`)
 - `INFERENCE_MODEL`: Inference model loaded into the vLLM server (default: `meta-llama/Llama-3.2-3B-Instruct`)
-- `VLLM_URL`: URL of the vLLM server with the main inference model (default: `http://host.docker.internal:5100}/v1`)
+- `VLLM_URL`: URL of the vLLM server with the main inference model (default: `http://host.docker.internal:5100/v1`)
 - `MAX_TOKENS`: Maximum number of tokens for generation (default: `4096`)
 - `SAFETY_VLLM_URL`: URL of the vLLM server with the safety model (default: `http://host.docker.internal:5101/v1`)
 - `SAFETY_MODEL`: Name of the safety (Llama-Guard) model to use (default: `meta-llama/Llama-Guard-3-1B`)
@@ -23,6 +23,7 @@ The `llamastack/distribution-tgi` distribution consists of the following provide
 | safety | `inline::llama-guard` |
 | scoring | `inline::basic`, `inline::llm-as-judge`, `inline::braintrust` |
 | telemetry | `inline::meta-reference` |
+| tool_runtime | `remote::brave-search`, `remote::tavily-search`, `inline::code-interpreter`, `inline::memory-runtime` |
 
 
 You can use this distribution if you have GPUs and want to run an independent TGI server container for running inference.
@@ -31,7 +32,7 @@ You can use this distribution if you have GPUs and want to run an independent TG
 
 The following environment variables can be configured:
 
-- `LLAMASTACK_PORT`: Port for the Llama Stack distribution server (default: `5001`)
+- `LLAMA_STACK_PORT`: Port for the Llama Stack distribution server (default: `5001`)
 - `INFERENCE_MODEL`: Inference model loaded into the TGI server (default: `meta-llama/Llama-3.2-3B-Instruct`)
 - `TGI_URL`: URL of the TGI server with the main inference model (default: `http://127.0.0.1:8080}/v1`)
 - `TGI_SAFETY_URL`: URL of the TGI server with the safety model (default: `http://127.0.0.1:8081/v1`)
@@ -22,13 +22,14 @@ The `llamastack/distribution-together` distribution consists of the following pr
 | safety | `inline::llama-guard` |
 | scoring | `inline::basic`, `inline::llm-as-judge`, `inline::braintrust` |
 | telemetry | `inline::meta-reference` |
+| tool_runtime | `remote::brave-search`, `remote::tavily-search`, `inline::code-interpreter`, `inline::memory-runtime` |
 
 
 ### Environment Variables
 
 The following environment variables can be configured:
 
-- `LLAMASTACK_PORT`: Port for the Llama Stack distribution server (default: `5001`)
+- `LLAMA_STACK_PORT`: Port for the Llama Stack distribution server (default: `5001`)
 - `TOGETHER_API_KEY`: Together.AI API Key (default: ``)
 
 ### Models
@@ -41,6 +42,7 @@ The following models are available by default:
 - `meta-llama/Llama-3.2-3B-Instruct`
 - `meta-llama/Llama-3.2-11B-Vision-Instruct`
 - `meta-llama/Llama-3.2-90B-Vision-Instruct`
+- `meta-llama/Llama-3.3-70B-Instruct`
 - `meta-llama/Llama-Guard-3-8B`
 - `meta-llama/Llama-Guard-3-11B-Vision`
 
@@ -97,20 +97,20 @@ To download models, you can use the llama download command.
 
 #### Downloading from [Meta](https://llama.meta.com/llama-downloads/)
 
-Here is an example download command to get the 3B-Instruct/11B-Vision-Instruct model. You will need META_URL which can be obtained from [here](https://llama.meta.com/docs/getting_the_models/meta/)
+Here is an example download command to get the 3B-Instruct/11B-Vision-Instruct model. You will need META_URL which can be obtained from [here](https://llama.meta.com/docs/getting_the_models/meta/). Note: You need to quote the META_URL
 
 Download the required checkpoints using the following commands:
 ```bash
 # download the 8B model, this can be run on a single GPU
-llama download --source meta --model-id Llama3.2-3B-Instruct --meta-url META_URL
+llama download --source meta --model-id Llama3.2-3B-Instruct --meta-url 'META_URL'
 
 # you can also get the 70B model, this will require 8 GPUs however
-llama download --source meta --model-id Llama3.2-11B-Vision-Instruct --meta-url META_URL
+llama download --source meta --model-id Llama3.2-11B-Vision-Instruct --meta-url 'META_URL'
 
 # llama-agents have safety enabled by default. For this, you will need
 # safety models -- Llama-Guard and Prompt-Guard
-llama download --source meta --model-id Prompt-Guard-86M --meta-url META_URL
-llama download --source meta --model-id Llama-Guard-3-1B --meta-url META_URL
+llama download --source meta --model-id Prompt-Guard-86M --meta-url 'META_URL'
+llama download --source meta --model-id Llama-Guard-3-1B --meta-url 'META_URL'
 ```
 
 #### Downloading from [Hugging Face](https://huggingface.co/meta-llama)
@@ -89,7 +89,7 @@ If you're looking for more specific topics, we have a [Zero to Hero Guide](#next
 ```
 ...
 Build Successful! Next steps:
-  1. Set the environment variables: LLAMASTACK_PORT, OLLAMA_URL, INFERENCE_MODEL, SAFETY_MODEL
+  1. Set the environment variables: LLAMA_STACK_PORT, OLLAMA_URL, INFERENCE_MODEL, SAFETY_MODEL
   2. `llama stack run /Users/<username>/.llama/distributions/llamastack-ollama/ollama-run.yaml
 ```
 
@ -18,15 +18,11 @@ from typing import (
|
||||||
Union,
|
Union,
|
||||||
)
|
)
|
||||||
|
|
||||||
from llama_models.llama3.api.datatypes import ToolParamDefinition
|
from llama_models.schema_utils import json_schema_type, register_schema, webmethod
|
||||||
|
|
||||||
from llama_models.schema_utils import json_schema_type, webmethod
|
|
||||||
|
|
||||||
from pydantic import BaseModel, ConfigDict, Field
|
from pydantic import BaseModel, ConfigDict, Field
|
||||||
from typing_extensions import Annotated
|
from typing_extensions import Annotated
|
||||||
|
|
||||||
from llama_stack.apis.common.content_types import InterleavedContent, URL
|
from llama_stack.apis.common.content_types import InterleavedContent, URL
|
||||||
from llama_stack.apis.common.deployment_types import RestAPIExecutionConfig
|
|
||||||
from llama_stack.apis.inference import (
|
from llama_stack.apis.inference import (
|
||||||
CompletionMessage,
|
CompletionMessage,
|
||||||
SamplingParams,
|
SamplingParams,
|
||||||
|
@ -40,166 +36,18 @@ from llama_stack.apis.inference import (
|
||||||
)
|
)
|
||||||
from llama_stack.apis.memory import MemoryBank
|
from llama_stack.apis.memory import MemoryBank
|
||||||
from llama_stack.apis.safety import SafetyViolation
|
from llama_stack.apis.safety import SafetyViolation
|
||||||
|
from llama_stack.apis.tools import ToolDef
|
||||||
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
|
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
|
||||||
class Attachment(BaseModel):
|
class Attachment(BaseModel):
|
||||||
content: InterleavedContent | URL
|
content: InterleavedContent | URL
|
||||||
mime_type: str
|
mime_type: str
|
||||||
|
|
||||||
|
|
||||||
class AgentTool(Enum):
|
+class Document(BaseModel):
+    content: InterleavedContent | URL
+    mime_type: str
-    brave_search = "brave_search"
-    wolfram_alpha = "wolfram_alpha"
-    photogen = "photogen"
-    code_interpreter = "code_interpreter"
-    function_call = "function_call"
-    memory = "memory"
-
-
-class ToolDefinitionCommon(BaseModel):
-    input_shields: Optional[List[str]] = Field(default_factory=list)
-    output_shields: Optional[List[str]] = Field(default_factory=list)
-
-
-class SearchEngineType(Enum):
-    bing = "bing"
-    brave = "brave"
-    tavily = "tavily"
-
-
-@json_schema_type
-class SearchToolDefinition(ToolDefinitionCommon):
-    # NOTE: brave_search is just a placeholder since model always uses
-    # brave_search as tool call name
-    type: Literal[AgentTool.brave_search.value] = AgentTool.brave_search.value
-    api_key: str
-    engine: SearchEngineType = SearchEngineType.brave
-    remote_execution: Optional[RestAPIExecutionConfig] = None
-
-
-@json_schema_type
-class WolframAlphaToolDefinition(ToolDefinitionCommon):
-    type: Literal[AgentTool.wolfram_alpha.value] = AgentTool.wolfram_alpha.value
-    api_key: str
-    remote_execution: Optional[RestAPIExecutionConfig] = None
-
-
-@json_schema_type
-class PhotogenToolDefinition(ToolDefinitionCommon):
-    type: Literal[AgentTool.photogen.value] = AgentTool.photogen.value
-    remote_execution: Optional[RestAPIExecutionConfig] = None
-
-
-@json_schema_type
-class CodeInterpreterToolDefinition(ToolDefinitionCommon):
-    type: Literal[AgentTool.code_interpreter.value] = AgentTool.code_interpreter.value
-    enable_inline_code_execution: bool = True
-    remote_execution: Optional[RestAPIExecutionConfig] = None
-
-
-@json_schema_type
-class FunctionCallToolDefinition(ToolDefinitionCommon):
-    type: Literal[AgentTool.function_call.value] = AgentTool.function_call.value
-    function_name: str
-    description: str
-    parameters: Dict[str, ToolParamDefinition]
-    remote_execution: Optional[RestAPIExecutionConfig] = None
-
-
-class _MemoryBankConfigCommon(BaseModel):
-    bank_id: str
-
-
-class AgentVectorMemoryBankConfig(_MemoryBankConfigCommon):
-    type: Literal["vector"] = "vector"
-
-
-class AgentKeyValueMemoryBankConfig(_MemoryBankConfigCommon):
-    type: Literal["keyvalue"] = "keyvalue"
-    keys: List[str]  # what keys to focus on
-
-
-class AgentKeywordMemoryBankConfig(_MemoryBankConfigCommon):
-    type: Literal["keyword"] = "keyword"
-
-
-class AgentGraphMemoryBankConfig(_MemoryBankConfigCommon):
-    type: Literal["graph"] = "graph"
-    entities: List[str]  # what entities to focus on
-
-
-MemoryBankConfig = Annotated[
-    Union[
-        AgentVectorMemoryBankConfig,
-        AgentKeyValueMemoryBankConfig,
-        AgentKeywordMemoryBankConfig,
-        AgentGraphMemoryBankConfig,
-    ],
-    Field(discriminator="type"),
-]
-
-
-class MemoryQueryGenerator(Enum):
-    default = "default"
-    llm = "llm"
-    custom = "custom"
-
-
-class DefaultMemoryQueryGeneratorConfig(BaseModel):
-    type: Literal[MemoryQueryGenerator.default.value] = (
-        MemoryQueryGenerator.default.value
-    )
-    sep: str = " "
-
-
-class LLMMemoryQueryGeneratorConfig(BaseModel):
-    type: Literal[MemoryQueryGenerator.llm.value] = MemoryQueryGenerator.llm.value
-    model: str
-    template: str
-
-
-class CustomMemoryQueryGeneratorConfig(BaseModel):
-    type: Literal[MemoryQueryGenerator.custom.value] = MemoryQueryGenerator.custom.value
-
-
-MemoryQueryGeneratorConfig = Annotated[
-    Union[
-        DefaultMemoryQueryGeneratorConfig,
-        LLMMemoryQueryGeneratorConfig,
-        CustomMemoryQueryGeneratorConfig,
-    ],
-    Field(discriminator="type"),
-]
-
-
-@json_schema_type
-class MemoryToolDefinition(ToolDefinitionCommon):
-    type: Literal[AgentTool.memory.value] = AgentTool.memory.value
-    memory_bank_configs: List[MemoryBankConfig] = Field(default_factory=list)
-    # This config defines how a query is generated using the messages
-    # for memory bank retrieval.
-    query_generator_config: MemoryQueryGeneratorConfig = Field(
-        default=DefaultMemoryQueryGeneratorConfig()
-    )
-    max_tokens_in_context: int = 4096
-    max_chunks: int = 10
-
-
-AgentToolDefinition = Annotated[
-    Union[
-        SearchToolDefinition,
-        WolframAlphaToolDefinition,
-        PhotogenToolDefinition,
-        CodeInterpreterToolDefinition,
-        FunctionCallToolDefinition,
-        MemoryToolDefinition,
-    ],
-    Field(discriminator="type"),
-]
 class StepCommon(BaseModel):

@@ -289,13 +137,27 @@ class Session(BaseModel):
     memory_bank: Optional[MemoryBank] = None


+class AgentToolGroupWithArgs(BaseModel):
+    name: str
+    args: Dict[str, Any]
+
+
+AgentToolGroup = register_schema(
+    Union[
+        str,
+        AgentToolGroupWithArgs,
+    ],
+    name="AgentTool",
+)
+
+
 class AgentConfigCommon(BaseModel):
     sampling_params: Optional[SamplingParams] = SamplingParams()

     input_shields: Optional[List[str]] = Field(default_factory=list)
     output_shields: Optional[List[str]] = Field(default_factory=list)
+    toolgroups: Optional[List[AgentToolGroup]] = Field(default_factory=list)
-    tools: Optional[List[AgentToolDefinition]] = Field(default_factory=list)
+    client_tools: Optional[List[ToolDef]] = Field(default_factory=list)
     tool_choice: Optional[ToolChoice] = Field(default=ToolChoice.auto)
     tool_prompt_format: Optional[ToolPromptFormat] = Field(
         default=ToolPromptFormat.json

@@ -340,6 +202,7 @@ class AgentTurnResponseStepCompletePayload(BaseModel):
         AgentTurnResponseEventType.step_complete.value
     )
     step_type: StepType
+    step_id: str
     step_details: Step


@@ -413,7 +276,9 @@ class AgentTurnCreateRequest(AgentConfigOverridablePerTurn):
             ToolResponseMessage,
         ]
     ]
-    attachments: Optional[List[Attachment]] = None
+    documents: Optional[List[Document]] = None
+    toolgroups: Optional[List[AgentToolGroup]] = None
+
     stream: Optional[bool] = False

@@ -450,8 +315,9 @@ class Agents(Protocol):
                 ToolResponseMessage,
             ]
         ],
-        attachments: Optional[List[Attachment]] = None,
         stream: Optional[bool] = False,
+        documents: Optional[List[Document]] = None,
+        toolgroups: Optional[List[AgentToolGroup]] = None,
    ) -> Union[Turn, AsyncIterator[AgentTurnResponseStreamChunk]]: ...

    @webmethod(route="/agents/turn/get")
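Taken together, the per-tool definition classes removed above are replaced by the much lighter AgentToolGroup union. A minimal sketch of what an agent's toolgroup list might look like under the new schema; the toolgroup ids and args are placeholders, not names introduced by this change:

from llama_stack.apis.agents import AgentToolGroupWithArgs

# Entries may be a bare toolgroup id (str) or a toolgroup plus per-agent args.
toolgroups = [
    "builtin::code_interpreter",  # hypothetical toolgroup id
    AgentToolGroupWithArgs(name="builtin::rag", args={"top_k": 3}),  # hypothetical id and args
]
# These would be supplied as AgentConfig.toolgroups, or per turn via
# create_agent_turn(..., toolgroups=...).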
@@ -7,7 +7,6 @@
 from typing import List, Optional, Protocol, runtime_checkable

 from llama_models.schema_utils import json_schema_type, webmethod

 from pydantic import BaseModel, Field

 from llama_stack.apis.inference import (

@@ -44,9 +43,7 @@ class BatchChatCompletionRequest(BaseModel):
     # zero-shot tool definitions as input to the model
     tools: Optional[List[ToolDefinition]] = Field(default_factory=list)
     tool_choice: Optional[ToolChoice] = Field(default=ToolChoice.auto)
-    tool_prompt_format: Optional[ToolPromptFormat] = Field(
-        default=ToolPromptFormat.json
-    )
+    tool_prompt_format: Optional[ToolPromptFormat] = Field(default=None)
     logprobs: Optional[LogProbConfig] = None


@@ -75,6 +72,6 @@ class BatchInference(Protocol):
         # zero-shot tool definitions as input to the model
         tools: Optional[List[ToolDefinition]] = list,
         tool_choice: Optional[ToolChoice] = ToolChoice.auto,
-        tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
+        tool_prompt_format: Optional[ToolPromptFormat] = None,
         logprobs: Optional[LogProbConfig] = None,
     ) -> BatchChatCompletionResponse: ...
@@ -5,7 +5,6 @@
 # the root directory of this source tree.

 from enum import Enum

 from typing import (
     Any,
     AsyncIterator,

@@ -26,16 +25,12 @@ from llama_models.llama3.api.datatypes import (
     ToolDefinition,
     ToolPromptFormat,
 )
 from llama_models.schema_utils import json_schema_type, register_schema, webmethod
 from pydantic import BaseModel, Field, field_validator
 from typing_extensions import Annotated

 from llama_stack.apis.common.content_types import InterleavedContent
 from llama_stack.apis.models import Model
 from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol


@@ -87,7 +82,7 @@ class SystemMessage(BaseModel):

 @json_schema_type
 class ToolResponseMessage(BaseModel):
-    role: Literal["ipython"] = "ipython"
+    role: Literal["tool"] = "tool"
     # it was nice to re-use the ToolResponse type, but having all messages
     # have a `content` type makes things nicer too
     call_id: str

@@ -256,9 +251,7 @@ class ChatCompletionRequest(BaseModel):
     # zero-shot tool definitions as input to the model
     tools: Optional[List[ToolDefinition]] = Field(default_factory=list)
     tool_choice: Optional[ToolChoice] = Field(default=ToolChoice.auto)
-    tool_prompt_format: Optional[ToolPromptFormat] = Field(
-        default=ToolPromptFormat.json
-    )
+    tool_prompt_format: Optional[ToolPromptFormat] = Field(default=None)
     response_format: Optional[ResponseFormat] = None

     stream: Optional[bool] = False

@@ -289,9 +282,7 @@ class BatchChatCompletionRequest(BaseModel):
     # zero-shot tool definitions as input to the model
     tools: Optional[List[ToolDefinition]] = Field(default_factory=list)
     tool_choice: Optional[ToolChoice] = Field(default=ToolChoice.auto)
-    tool_prompt_format: Optional[ToolPromptFormat] = Field(
-        default=ToolPromptFormat.json
-    )
+    tool_prompt_format: Optional[ToolPromptFormat] = Field(default=None)
     logprobs: Optional[LogProbConfig] = None


@@ -334,7 +325,7 @@ class Inference(Protocol):
         # zero-shot tool definitions as input to the model
         tools: Optional[List[ToolDefinition]] = None,
         tool_choice: Optional[ToolChoice] = ToolChoice.auto,
-        tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
+        tool_prompt_format: Optional[ToolPromptFormat] = None,
         response_format: Optional[ResponseFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
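Two behavioral consequences of the changes above: tool results are now tagged with role "tool" instead of "ipython", and tool_prompt_format no longer silently defaults to JSON. A small sketch, assuming the remaining ToolResponseMessage fields (tool_name, content) as they stood at this revision; the ids and values are placeholders:

from llama_stack.apis.inference import ToolResponseMessage

# role now defaults to "tool" after this change.
tool_msg = ToolResponseMessage(
    call_id="call-123",       # placeholder call id
    tool_name="get_weather",  # placeholder tool name
    content="72F and sunny",
)
# Leaving tool_prompt_format unset in chat_completion lets the provider pick a
# format appropriate for the model instead of always assuming JSON.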
@@ -29,6 +29,11 @@ class HealthInfo(BaseModel):
     # TODO: add a provider level status


+@json_schema_type
+class VersionInfo(BaseModel):
+    version: str
+
+
 @runtime_checkable
 class Inspect(Protocol):
     @webmethod(route="/providers/list", method="GET")

@@ -39,3 +44,6 @@ class Inspect(Protocol):

     @webmethod(route="/health", method="GET")
     async def health(self) -> HealthInfo: ...
+
+    @webmethod(route="/version", method="GET")
+    async def version(self) -> VersionInfo: ...
@@ -4,10 +4,10 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from typing import Annotated, Any, Dict, List, Literal, Optional, Union
+from enum import Enum
+from typing import Any, Dict, List, Literal, Optional

-from llama_models.llama3.api.datatypes import ToolPromptFormat
-from llama_models.schema_utils import json_schema_type, register_schema, webmethod
+from llama_models.schema_utils import json_schema_type, webmethod
 from pydantic import BaseModel, Field
 from typing_extensions import Protocol, runtime_checkable

@@ -21,59 +21,48 @@ class ToolParameter(BaseModel):
     name: str
     parameter_type: str
     description: str
+    required: bool = Field(default=True)
+    default: Optional[Any] = None


+@json_schema_type
+class ToolHost(Enum):
+    distribution = "distribution"
+    client = "client"
+    model_context_protocol = "model_context_protocol"
+
+
 @json_schema_type
 class Tool(Resource):
     type: Literal[ResourceType.tool.value] = ResourceType.tool.value
-    tool_group: str
+    toolgroup_id: str
+    tool_host: ToolHost
     description: str
     parameters: List[ToolParameter]
-    provider_id: Optional[str] = None
     metadata: Optional[Dict[str, Any]] = None
-    tool_prompt_format: Optional[ToolPromptFormat] = Field(
-        default=ToolPromptFormat.json
-    )


 @json_schema_type
 class ToolDef(BaseModel):
     name: str
-    description: str
-    parameters: List[ToolParameter]
-    metadata: Dict[str, Any]
-    tool_prompt_format: Optional[ToolPromptFormat] = Field(
-        default=ToolPromptFormat.json
-    )
+    description: Optional[str] = None
+    parameters: Optional[List[ToolParameter]] = None
+    metadata: Optional[Dict[str, Any]] = None


 @json_schema_type
-class MCPToolGroupDef(BaseModel):
-    """
-    A tool group that is defined by in a model context protocol server.
-    Refer to https://modelcontextprotocol.io/docs/concepts/tools for more information.
-    """
-
-    type: Literal["model_context_protocol"] = "model_context_protocol"
-    endpoint: URL
+class ToolGroupInput(BaseModel):
+    toolgroup_id: str
+    provider_id: str
+    args: Optional[Dict[str, Any]] = None
+    mcp_endpoint: Optional[URL] = None


 @json_schema_type
-class UserDefinedToolGroupDef(BaseModel):
-    type: Literal["user_defined"] = "user_defined"
-    tools: List[ToolDef]
-
-
-ToolGroupDef = register_schema(
-    Annotated[
-        Union[MCPToolGroupDef, UserDefinedToolGroupDef], Field(discriminator="type")
-    ],
-    name="ToolGroup",
-)
-
-
 class ToolGroup(Resource):
     type: Literal[ResourceType.tool_group.value] = ResourceType.tool_group.value
+    mcp_endpoint: Optional[URL] = None
+    args: Optional[Dict[str, Any]] = None


 @json_schema_type

@@ -85,6 +74,7 @@ class ToolInvocationResult(BaseModel):

 class ToolStore(Protocol):
     def get_tool(self, tool_name: str) -> Tool: ...
+    def get_tool_group(self, tool_group_id: str) -> ToolGroup: ...


 @runtime_checkable

@@ -93,9 +83,10 @@ class ToolGroups(Protocol):
     @webmethod(route="/toolgroups/register", method="POST")
     async def register_tool_group(
         self,
-        tool_group_id: str,
-        tool_group: ToolGroupDef,
-        provider_id: Optional[str] = None,
+        toolgroup_id: str,
+        provider_id: str,
+        mcp_endpoint: Optional[URL] = None,
+        args: Optional[Dict[str, Any]] = None,
     ) -> None:
         """Register a tool group"""
         ...

@@ -103,7 +94,7 @@ class ToolGroups(Protocol):
     @webmethod(route="/toolgroups/get", method="GET")
     async def get_tool_group(
         self,
-        tool_group_id: str,
+        toolgroup_id: str,
     ) -> ToolGroup: ...

     @webmethod(route="/toolgroups/list", method="GET")

@@ -130,8 +121,11 @@ class ToolGroups(Protocol):
 class ToolRuntime(Protocol):
     tool_store: ToolStore

-    @webmethod(route="/tool-runtime/discover", method="POST")
-    async def discover_tools(self, tool_group: ToolGroupDef) -> List[ToolDef]: ...
+    # TODO: This needs to be renamed once OPEN API generator name conflict issue is fixed.
+    @webmethod(route="/tool-runtime/list-tools", method="GET")
+    async def list_runtime_tools(
+        self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
+    ) -> List[ToolDef]: ...

     @webmethod(route="/tool-runtime/invoke", method="POST")
     async def invoke_tool(
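To make the reworked surface concrete, a sketch of registering and then listing an MCP-backed tool group with the new signatures; tool_groups_api and tool_runtime_api stand in for however the ToolGroups and ToolRuntime implementations are obtained, and the ids and endpoint are placeholders:

from llama_stack.apis.common.content_types import URL

async def register_example(tool_groups_api, tool_runtime_api):
    # Hypothetical ids and endpoint; the provider id is now passed directly
    # rather than wrapped in a ToolGroupDef.
    await tool_groups_api.register_tool_group(
        toolgroup_id="mcp::filesystem",
        provider_id="model-context-protocol",
        mcp_endpoint=URL(uri="http://localhost:8000/sse"),
    )
    # Tools exposed by the group can then be enumerated at runtime.
    return await tool_runtime_api.list_runtime_tools("mcp::filesystem")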
@@ -43,7 +43,7 @@ class ModelPromptFormat(Subcommand):
         )

     def _run_model_template_cmd(self, args: argparse.Namespace) -> None:
-        import pkg_resources
+        import importlib.resources

         # Only Llama 3.1 and 3.2 are supported
         supported_model_ids = [

@@ -64,25 +64,26 @@ class ModelPromptFormat(Subcommand):
                 f"{model_id} is not a valid Model. Choose one from --\n {model_str}"
             )

-        llama_3_1_file = pkg_resources.resource_filename(
-            "llama_models", "llama3_1/prompt_format.md"
+        llama_3_1_file = (
+            importlib.resources.files("llama_models") / "llama3_1/prompt_format.md"
         )
-        llama_3_2_text_file = pkg_resources.resource_filename(
-            "llama_models", "llama3_2/text_prompt_format.md"
+        llama_3_2_text_file = (
+            importlib.resources.files("llama_models") / "llama3_2/text_prompt_format.md"
         )
-        llama_3_2_vision_file = pkg_resources.resource_filename(
-            "llama_models", "llama3_2/vision_prompt_format.md"
+        llama_3_2_vision_file = (
+            importlib.resources.files("llama_models")
+            / "llama3_2/vision_prompt_format.md"
         )
         if model_family(model_id) == ModelFamily.llama3_1:
-            with open(llama_3_1_file, "r") as f:
-                content = f.read()
+            with importlib.resources.as_file(llama_3_1_file) as f:
+                content = f.open("r").read()
         elif model_family(model_id) == ModelFamily.llama3_2:
             if is_multimodal(model_id):
-                with open(llama_3_2_vision_file, "r") as f:
-                    content = f.read()
+                with importlib.resources.as_file(llama_3_2_vision_file) as f:
+                    content = f.open("r").read()
             else:
-                with open(llama_3_2_text_file, "r") as f:
-                    content = f.read()
+                with importlib.resources.as_file(llama_3_2_text_file) as f:
+                    content = f.open("r").read()

         render_markdown_to_pager(content)
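The pkg_resources to importlib.resources migration seen here recurs through the rest of this change; a standalone sketch of the pattern, with illustrative package and file names:

import importlib.resources

# Replaces pkg_resources.resource_filename("some_pkg", "data/file.md")
resource = importlib.resources.files("some_pkg") / "data/file.md"

# as_file() guarantees a real filesystem path even if the package ships zipped.
with importlib.resources.as_file(resource) as path:
    content = path.read_text()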
@@ -4,14 +4,15 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 import argparse
+import importlib.resources
 import os
 import shutil
 from functools import lru_cache
 from pathlib import Path
 from typing import List, Optional

-import pkg_resources
-
 from llama_stack.cli.subcommand import Subcommand

 from llama_stack.distribution.datatypes import (

@@ -290,13 +291,12 @@ class StackBuild(Subcommand):

         if template_name:
             # copy run.yaml from template to build_dir instead of generating it again
-            template_path = pkg_resources.resource_filename(
-                "llama_stack", f"templates/{template_name}/run.yaml"
+            template_path = (
+                importlib.resources.files("llama_stack")
+                / f"templates/{template_name}/run.yaml"
             )
-            os.makedirs(build_dir, exist_ok=True)
-            run_config_file = build_dir / f"{build_config.name}-run.yaml"
-            shutil.copy(template_path, run_config_file)
+            with importlib.resources.as_file(template_path) as path:
+                shutil.copy(path, run_config_file)

             # Find all ${env.VARIABLE} patterns
             cprint("Build Successful!", color="green")
         else:
@@ -5,6 +5,7 @@
 # the root directory of this source tree.

 import argparse
+import os
 from pathlib import Path

 from llama_stack.cli.subcommand import Subcommand

@@ -34,7 +35,7 @@ class StackRun(Subcommand):
             "--port",
             type=int,
             help="Port to run the server on. Defaults to 5000",
-            default=5000,
+            default=int(os.getenv("LLAMA_STACK_PORT", 5000)),
         )
         self.parser.add_argument(
             "--disable-ipv6",

@@ -51,7 +52,8 @@ class StackRun(Subcommand):
         )

     def _run_stack_run_cmd(self, args: argparse.Namespace) -> None:
-        import pkg_resources
+        import importlib.resources

         import yaml

         from llama_stack.distribution.build import ImageType

@@ -106,15 +108,15 @@ class StackRun(Subcommand):
         config = parse_and_maybe_upgrade_config(config_dict)

         if config.docker_image:
-            script = pkg_resources.resource_filename(
-                "llama_stack",
-                "distribution/start_container.sh",
+            script = (
+                importlib.resources.files("llama_stack")
+                / "distribution/start_container.sh"
             )
             run_args = [script, config.docker_image]
         else:
-            script = pkg_resources.resource_filename(
-                "llama_stack",
-                "distribution/start_conda_env.sh",
+            script = (
+                importlib.resources.files("llama_stack")
+                / "distribution/start_conda_env.sh"
             )
             run_args = [
                 script,
@@ -5,6 +5,7 @@
 # the root directory of this source tree.

 import argparse
+from importlib.metadata import version

 from llama_stack.cli.subcommand import Subcommand

@@ -24,6 +25,12 @@ class StackParser(Subcommand):
             description="Operations for the Llama Stack / Distributions",
         )

+        self.parser.add_argument(
+            "--version",
+            action="version",
+            version=f"{version('llama-stack')}",
+        )
+
         subparsers = self.parser.add_subparsers(title="stack_subcommands")

         # Add sub-commands
@@ -4,13 +4,13 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

+import importlib.resources
 import logging
 from enum import Enum
 from pathlib import Path
 from typing import Dict, List

-import pkg_resources
 from pydantic import BaseModel
 from termcolor import cprint

@@ -111,8 +111,8 @@ def build_image(build_config: BuildConfig, build_file_path: Path):
     normal_deps += SERVER_DEPENDENCIES

     if build_config.image_type == ImageType.docker.value:
-        script = pkg_resources.resource_filename(
-            "llama_stack", "distribution/build_container.sh"
+        script = (
+            importlib.resources.files("llama_stack") / "distribution/build_container.sh"
         )
         args = [
             script,

@@ -123,8 +123,8 @@ def build_image(build_config: BuildConfig, build_file_path: Path):
             " ".join(normal_deps),
         ]
     elif build_config.image_type == ImageType.conda.value:
-        script = pkg_resources.resource_filename(
-            "llama_stack", "distribution/build_conda_env.sh"
+        script = (
+            importlib.resources.files("llama_stack") / "distribution/build_conda_env.sh"
         )
         args = [
             script,

@@ -133,9 +133,7 @@ def build_image(build_config: BuildConfig, build_file_path: Path):
             " ".join(normal_deps),
         ]
     elif build_config.image_type == ImageType.venv.value:
-        script = pkg_resources.resource_filename(
-            "llama_stack", "distribution/build_venv.sh"
-        )
+        script = importlib.resources.files("llama_stack") / "distribution/build_venv.sh"
         args = [
             script,
             build_config.name,
@@ -51,7 +51,19 @@ add_to_docker() {
   fi
 }

-add_to_docker <<EOF
+# Update and install UBI9 components if UBI9 base image is used
+if [[ $docker_base == *"registry.access.redhat.com/ubi9"* ]]; then
+  add_to_docker <<EOF
+FROM $docker_base
+WORKDIR /app
+
+RUN microdnf -y update && microdnf install -y iputils net-tools wget \
+    vim-minimal python3.11 python3.11-pip python3.11-wheel \
+    python3.11-setuptools && ln -s /bin/pip3.11 /bin/pip && ln -s /bin/python3.11 /bin/python && microdnf clean all
+
+EOF
+else
+  add_to_docker <<EOF
 FROM $docker_base
 WORKDIR /app

@@ -64,6 +76,7 @@ RUN apt-get update && apt-get install -y \
        && rm -rf /var/lib/apt/lists/*

 EOF
+fi

 # Add pip dependencies first since llama-stack is what will change most often
 # so we can reuse layers.
@@ -20,7 +20,7 @@ from llama_stack.apis.safety import Safety
 from llama_stack.apis.scoring import Scoring
 from llama_stack.apis.scoring_functions import ScoringFn, ScoringFnInput
 from llama_stack.apis.shields import Shield, ShieldInput
-from llama_stack.apis.tools import Tool, ToolGroup, ToolRuntime
+from llama_stack.apis.tools import Tool, ToolGroup, ToolGroupInput, ToolRuntime
 from llama_stack.providers.datatypes import Api, ProviderSpec
 from llama_stack.providers.utils.kvstore.config import KVStoreConfig

@@ -161,6 +161,7 @@ a default SQLite store will be used.""",
     datasets: List[DatasetInput] = Field(default_factory=list)
     scoring_fns: List[ScoringFnInput] = Field(default_factory=list)
     eval_tasks: List[EvalTaskInput] = Field(default_factory=list)
+    tool_groups: List[ToolGroupInput] = Field(default_factory=list)


 class BuildConfig(BaseModel):
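With the new tool_groups field, a run config can pre-register tool groups alongside models, shields and datasets. A sketch of what such an entry might look like in Python form; the ids are placeholders:

from llama_stack.apis.tools import ToolGroupInput

tool_groups = [
    ToolGroupInput(
        toolgroup_id="builtin::memory",  # hypothetical toolgroup id
        provider_id="memory-runtime",    # hypothetical provider id
    ),
]
# A StackRunConfig(..., tool_groups=tool_groups) would register these at startup.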
@@ -4,11 +4,18 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

+from importlib.metadata import version
 from typing import Dict, List

 from pydantic import BaseModel

-from llama_stack.apis.inspect import HealthInfo, Inspect, ProviderInfo, RouteInfo
+from llama_stack.apis.inspect import (
+    HealthInfo,
+    Inspect,
+    ProviderInfo,
+    RouteInfo,
+    VersionInfo,
+)
 from llama_stack.distribution.datatypes import StackRunConfig
 from llama_stack.distribution.server.endpoints import get_all_api_endpoints

@@ -65,3 +72,6 @@ class DistributionInspectImpl(Inspect):

     async def health(self) -> HealthInfo:
         return HealthInfo(status="OK")
+
+    async def version(self) -> VersionInfo:
+        return VersionInfo(version=version("llama-stack"))
@@ -33,6 +33,7 @@ from termcolor import cprint
 from llama_stack.distribution.build import print_pip_install_help
 from llama_stack.distribution.configure import parse_and_maybe_upgrade_config
 from llama_stack.distribution.datatypes import Api
+from llama_stack.distribution.request_headers import set_request_provider_data
 from llama_stack.distribution.resolver import ProviderRegistry
 from llama_stack.distribution.server.endpoints import get_all_api_endpoints
 from llama_stack.distribution.stack import (

@@ -67,6 +68,7 @@ def stream_across_asyncio_run_boundary(
     async_gen_maker,
     pool_executor: ThreadPoolExecutor,
     path: Optional[str] = None,
+    provider_data: Optional[dict[str, Any]] = None,
 ) -> Generator[T, None, None]:
     result_queue = queue.Queue()
     stop_event = threading.Event()

@@ -75,6 +77,10 @@ def stream_across_asyncio_run_boundary(
         # make sure we make the generator in the event loop context
         gen = await async_gen_maker()
         await start_trace(path, {"__location__": "library_client"})
+        if provider_data:
+            set_request_provider_data(
+                {"X-LlamaStack-Provider-Data": json.dumps(provider_data)}
+            )
         try:
             async for item in await gen:
                 result_queue.put(item)

@@ -174,6 +180,7 @@ class LlamaStackAsLibraryClient(LlamaStackClient):
         config_path_or_template_name: str,
         skip_logger_removal: bool = False,
         custom_provider_registry: Optional[ProviderRegistry] = None,
+        provider_data: Optional[dict[str, Any]] = None,
     ):
         super().__init__()
         self.async_client = AsyncLlamaStackAsLibraryClient(

@@ -181,6 +188,7 @@ class LlamaStackAsLibraryClient(LlamaStackClient):
         )
         self.pool_executor = ThreadPoolExecutor(max_workers=4)
         self.skip_logger_removal = skip_logger_removal
+        self.provider_data = provider_data

     def initialize(self):
         if in_notebook():

@@ -219,10 +227,15 @@ class LlamaStackAsLibraryClient(LlamaStackClient):
                 lambda: self.async_client.request(*args, **kwargs),
                 self.pool_executor,
                 path=path,
+                provider_data=self.provider_data,
             )
         else:

             async def _traced_request():
+                if self.provider_data:
+                    set_request_provider_data(
+                        {"X-LlamaStack-Provider-Data": json.dumps(self.provider_data)}
+                    )
                 await start_trace(path, {"__location__": "library_client"})
                 try:
                     return await self.async_client.request(*args, **kwargs)

@@ -267,6 +280,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
                 self.config, self.custom_provider_registry
             )
         except ModuleNotFoundError as _e:
+            cprint(_e.msg, "red")
             cprint(
                 "Using llama-stack as a library requires installing dependencies depending on the template (providers) you choose.\n",
                 "yellow",
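A brief sketch of the new provider_data hook from application code; the template name and key are placeholders, and the dict is forwarded on each request via the X-LlamaStack-Provider-Data header described below:

from llama_stack.distribution.library_client import LlamaStackAsLibraryClient

client = LlamaStackAsLibraryClient(
    "together",                                 # hypothetical template name
    provider_data={"together_api_key": "..."},  # hypothetical provider credential
)
client.initialize()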
@@ -40,8 +40,8 @@ class NeedsRequestProviderData:

 def set_request_provider_data(headers: Dict[str, str]):
     keys = [
-        "X-LlamaStack-ProviderData",
-        "x-llamastack-providerdata",
+        "X-LlamaStack-Provider-Data",
+        "x-llamastack-provider-data",
     ]
     for key in keys:
         val = headers.get(key, None)
@@ -5,9 +5,7 @@
 # the root directory of this source tree.
 import importlib
 import inspect
 import logging

 from typing import Any, Dict, List, Set

 from llama_stack.apis.agents import Agents

@@ -28,7 +26,6 @@ from llama_stack.apis.shields import Shields
 from llama_stack.apis.telemetry import Telemetry
 from llama_stack.apis.tools import ToolGroups, ToolRuntime
 from llama_stack.distribution.client import get_client_impl
-
 from llama_stack.distribution.datatypes import (
     AutoRoutedProviderSpec,
     Provider,

@@ -38,7 +35,6 @@ from llama_stack.distribution.datatypes import (
 from llama_stack.distribution.distribution import builtin_automatically_routed_apis
 from llama_stack.distribution.store import DistributionRegistry
 from llama_stack.distribution.utils.dynamic import instantiate_class_type
-
 from llama_stack.providers.datatypes import (
     Api,
     DatasetsProtocolPrivate,
@@ -6,7 +6,7 @@

 from typing import Any, AsyncGenerator, Dict, List, Optional

-from llama_stack.apis.common.content_types import InterleavedContent
+from llama_stack.apis.common.content_types import InterleavedContent, URL
 from llama_stack.apis.datasetio import DatasetIO, PaginatedRowsResult
 from llama_stack.apis.eval import (
     AppEvalTaskConfig,

@@ -38,7 +38,7 @@ from llama_stack.apis.scoring import (
     ScoringFnParams,
 )
 from llama_stack.apis.shields import Shield
-from llama_stack.apis.tools import Tool, ToolGroupDef, ToolRuntime
+from llama_stack.apis.tools import ToolDef, ToolRuntime
 from llama_stack.providers.datatypes import RoutingTable

@@ -127,7 +127,7 @@ class InferenceRouter(Inference):
         response_format: Optional[ResponseFormat] = None,
         tools: Optional[List[ToolDefinition]] = None,
         tool_choice: Optional[ToolChoice] = ToolChoice.auto,
-        tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
+        tool_prompt_format: Optional[ToolPromptFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> AsyncGenerator:

@@ -417,7 +417,9 @@ class ToolRuntimeRouter(ToolRuntime):
             args=args,
         )

-    async def discover_tools(self, tool_group: ToolGroupDef) -> List[Tool]:
-        return await self.routing_table.get_provider_impl(
-            tool_group.name
-        ).discover_tools(tool_group)
+    async def list_runtime_tools(
+        self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
+    ) -> List[ToolDef]:
+        return await self.routing_table.get_provider_impl(tool_group_id).list_tools(
+            tool_group_id, mcp_endpoint
+        )
@@ -6,7 +6,7 @@

 from typing import Any, Dict, List, Optional

-from pydantic import parse_obj_as
+from pydantic import TypeAdapter

 from llama_stack.apis.common.content_types import URL
 from llama_stack.apis.common.type_system import ParamType

@@ -26,20 +26,12 @@ from llama_stack.apis.scoring_functions import (
     ScoringFunctions,
 )
 from llama_stack.apis.shields import Shield, Shields
-from llama_stack.apis.tools import (
-    MCPToolGroupDef,
-    Tool,
-    ToolGroup,
-    ToolGroupDef,
-    ToolGroups,
-    UserDefinedToolGroupDef,
-)
+from llama_stack.apis.tools import Tool, ToolGroup, ToolGroups, ToolHost
 from llama_stack.distribution.datatypes import (
     RoutableObject,
     RoutableObjectWithProvider,
     RoutedProtocol,
 )
 from llama_stack.distribution.store import DistributionRegistry
 from llama_stack.providers.datatypes import Api, RoutingTable

@@ -361,7 +353,7 @@ class MemoryBanksRoutingTable(CommonRoutingTableImpl, MemoryBanks):
             memory_bank_data["embedding_dimension"] = model.metadata[
                 "embedding_dimension"
             ]
-        memory_bank = parse_obj_as(MemoryBank, memory_bank_data)
+        memory_bank = TypeAdapter(MemoryBank).validate_python(memory_bank_data)
         await self.register_object(memory_bank)
         return memory_bank

@@ -496,54 +488,44 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
     async def list_tools(self, tool_group_id: Optional[str] = None) -> List[Tool]:
         tools = await self.get_all_with_type("tool")
         if tool_group_id:
-            tools = [tool for tool in tools if tool.tool_group == tool_group_id]
+            tools = [tool for tool in tools if tool.toolgroup_id == tool_group_id]
         return tools

     async def list_tool_groups(self) -> List[ToolGroup]:
         return await self.get_all_with_type("tool_group")

-    async def get_tool_group(self, tool_group_id: str) -> ToolGroup:
-        return await self.get_object_by_identifier("tool_group", tool_group_id)
+    async def get_tool_group(self, toolgroup_id: str) -> ToolGroup:
+        return await self.get_object_by_identifier("tool_group", toolgroup_id)

     async def get_tool(self, tool_name: str) -> Tool:
         return await self.get_object_by_identifier("tool", tool_name)

     async def register_tool_group(
         self,
-        tool_group_id: str,
-        tool_group: ToolGroupDef,
-        provider_id: Optional[str] = None,
+        toolgroup_id: str,
+        provider_id: str,
+        mcp_endpoint: Optional[URL] = None,
+        args: Optional[Dict[str, Any]] = None,
     ) -> None:
         tools = []
-        tool_defs = []
-        if provider_id is None:
-            if len(self.impls_by_provider_id.keys()) > 1:
-                raise ValueError(
-                    f"No provider_id specified and multiple providers available. Please specify a provider_id. Available providers: {', '.join(self.impls_by_provider_id.keys())}"
-                )
-            provider_id = list(self.impls_by_provider_id.keys())[0]
-
-        if isinstance(tool_group, MCPToolGroupDef):
-            tool_defs = await self.impls_by_provider_id[provider_id].discover_tools(
-                tool_group
-            )
-        elif isinstance(tool_group, UserDefinedToolGroupDef):
-            tool_defs = tool_group.tools
-        else:
-            raise ValueError(f"Unknown tool group: {tool_group}")
-
+        tool_defs = await self.impls_by_provider_id[provider_id].list_runtime_tools(
+            toolgroup_id, mcp_endpoint
+        )
+        tool_host = (
+            ToolHost.model_context_protocol if mcp_endpoint else ToolHost.distribution
+        )
         for tool_def in tool_defs:
             tools.append(
                 Tool(
                     identifier=tool_def.name,
-                    tool_group=tool_group_id,
-                    description=tool_def.description,
-                    parameters=tool_def.parameters,
+                    toolgroup_id=toolgroup_id,
+                    description=tool_def.description or "",
+                    parameters=tool_def.parameters or [],
                     provider_id=provider_id,
-                    tool_prompt_format=tool_def.tool_prompt_format,
                     provider_resource_id=tool_def.name,
                     metadata=tool_def.metadata,
+                    tool_host=tool_host,
                 )
             )
         for tool in tools:

@@ -561,9 +543,11 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):

         await self.dist_registry.register(
             ToolGroup(
-                identifier=tool_group_id,
+                identifier=toolgroup_id,
                 provider_id=provider_id,
-                provider_resource_id=tool_group_id,
+                provider_resource_id=toolgroup_id,
+                mcp_endpoint=mcp_endpoint,
+                args=args,
             )
         )
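The parse_obj_as to TypeAdapter switch above is the usual Pydantic v2 migration; a minimal illustration with a throwaway model:

from pydantic import BaseModel, TypeAdapter

class Example(BaseModel):
    name: str

# Pydantic v1 style was: parse_obj_as(Example, {"name": "x"})
obj = TypeAdapter(Example).validate_python({"name": "x"})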
@@ -16,6 +16,8 @@ import traceback
 import warnings

 from contextlib import asynccontextmanager

+from importlib.metadata import version as parse_version
 from pathlib import Path
 from typing import Any, Union

@@ -228,6 +230,52 @@ class TracingMiddleware:
             await end_trace()


+class ClientVersionMiddleware:
+    def __init__(self, app):
+        self.app = app
+        self.server_version = parse_version("llama-stack")
+
+    async def __call__(self, scope, receive, send):
+        if scope["type"] == "http":
+            headers = dict(scope.get("headers", []))
+            client_version = headers.get(b"x-llamastack-client-version", b"").decode()
+            if client_version:
+                try:
+                    client_version_parts = tuple(
+                        map(int, client_version.split(".")[:2])
+                    )
+                    server_version_parts = tuple(
+                        map(int, self.server_version.split(".")[:2])
+                    )
+                    if client_version_parts != server_version_parts:
+
+                        async def send_version_error(send):
+                            await send(
+                                {
+                                    "type": "http.response.start",
+                                    "status": 426,
+                                    "headers": [[b"content-type", b"application/json"]],
+                                }
+                            )
+                            error_msg = json.dumps(
+                                {
+                                    "error": {
+                                        "message": f"Client version {client_version} is not compatible with server version {self.server_version}. Please upgrade your client."
+                                    }
+                                }
+                            ).encode()
+                            await send(
+                                {"type": "http.response.body", "body": error_msg}
+                            )
+
+                        return await send_version_error(send)
+                except (ValueError, IndexError):
+                    # If version parsing fails, let the request through
+                    pass
+
+        return await self.app(scope, receive, send)
+
+
 def main():
     """Start the LlamaStack server."""
     parser = argparse.ArgumentParser(description="Start the LlamaStack server.")

@@ -242,7 +290,7 @@ def main():
     parser.add_argument(
         "--port",
         type=int,
-        default=int(os.getenv("LLAMASTACK_PORT", 5000)),
+        default=int(os.getenv("LLAMA_STACK_PORT", 5000)),
         help="Port to listen on",
     )
     parser.add_argument(

@@ -291,6 +339,7 @@ def main():

     app = FastAPI(lifespan=lifespan)
     app.add_middleware(TracingMiddleware)
+    app.add_middleware(ClientVersionMiddleware)

     try:
         impls = asyncio.run(construct_stack(config))
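A sketch of what the new middleware looks like from the client side; the server URL and path are illustrative only, and the middleware compares just the major.minor parts of the reported version:

import httpx

resp = httpx.get(
    "http://localhost:5000/alpha/health",               # placeholder server URL and path
    headers={"X-LlamaStack-Client-Version": "0.0.63"},  # version the client reports
)
if resp.status_code == 426:
    # Upgrade Required: major.minor did not match the server's llama-stack version.
    print(resp.json()["error"]["message"])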
@@ -4,15 +4,13 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

+import importlib.resources
 import logging
 import os
 import re
-from pathlib import Path
 from typing import Any, Dict, Optional

-import pkg_resources
 import yaml

 from termcolor import colored

 from llama_stack.apis.agents import Agents

@@ -33,14 +31,13 @@ from llama_stack.apis.scoring_functions import ScoringFunctions
 from llama_stack.apis.shields import Shields
 from llama_stack.apis.synthetic_data_generation import SyntheticDataGeneration
 from llama_stack.apis.telemetry import Telemetry
+from llama_stack.apis.tools import ToolGroups, ToolRuntime
 from llama_stack.distribution.datatypes import StackRunConfig
 from llama_stack.distribution.distribution import get_provider_registry
 from llama_stack.distribution.resolver import ProviderRegistry, resolve_impls
 from llama_stack.distribution.store.registry import create_dist_registry
 from llama_stack.providers.datatypes import Api


 log = logging.getLogger(__name__)

 LLAMA_STACK_API_VERSION = "alpha"

@@ -65,6 +62,8 @@ class LlamaStack(
     Models,
     Shields,
     Inspect,
+    ToolGroups,
+    ToolRuntime,
 ):
     pass

@@ -81,6 +80,7 @@ RESOURCES = [
         "list_scoring_functions",
     ),
     ("eval_tasks", Api.eval_tasks, "register_eval_task", "list_eval_tasks"),
+    ("tool_groups", Api.tool_groups, "register_tool_group", "list_tool_groups"),
 ]


@@ -210,14 +210,13 @@ async def construct_stack(


 def get_stack_run_config_from_template(template: str) -> StackRunConfig:
-    template_path = pkg_resources.resource_filename(
-        "llama_stack", f"templates/{template}/run.yaml"
+    template_path = (
+        importlib.resources.files("llama_stack") / f"templates/{template}/run.yaml"
     )

-    if not Path(template_path).exists():
-        raise ValueError(f"Template '{template}' not found at {template_path}")
-
-    with open(template_path) as f:
-        run_config = yaml.safe_load(f)
+    with importlib.resources.as_file(template_path) as path:
+        if not path.exists():
+            raise ValueError(f"Template '{template}' not found at {template_path}")
+        run_config = yaml.safe_load(path.open())

     return StackRunConfig(**replace_env_vars(run_config))
@@ -90,6 +90,6 @@ $DOCKER_BINARY run $DOCKER_OPTS -it \
 $env_vars \
 -v "$yaml_config:/app/config.yaml" \
 $mounts \
---env LLAMASTACK_PORT=$port \
+--env LLAMA_STACK_PORT=$port \
 --entrypoint='["python", "-m", "llama_stack.distribution.server.server", "--yaml-config", "/app/config.yaml"]' \
 $docker_image:$version_tag
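The container start script now passes the port as LLAMA_STACK_PORT rather than LLAMASTACK_PORT, so anything that still reads the old variable will see nothing. A small sketch of reading it on the server side; the fallback to the old name and the default of 5000 are assumptions for illustration, not behavior defined by this change:

    import os

    # Prefer the new name; the old-name fallback is only a transitional guess.
    port = int(
        os.environ.get("LLAMA_STACK_PORT")
        or os.environ.get("LLAMASTACK_PORT")
        or 5000
    )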
@@ -12,7 +12,6 @@ import pydantic

 from llama_stack.distribution.datatypes import KVStoreConfig, RoutableObjectWithProvider
 from llama_stack.distribution.utils.config_dirs import DISTRIBS_BASE_DIR
-
 from llama_stack.providers.utils.kvstore import KVStore, kvstore_impl
 from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig

@@ -36,7 +35,7 @@ class DistributionRegistry(Protocol):


 REGISTER_PREFIX = "distributions:registry"
-KEY_VERSION = "v3"
+KEY_VERSION = "v4"
 KEY_FORMAT = f"{REGISTER_PREFIX}:{KEY_VERSION}::" + "{type}:{identifier}"

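Bumping KEY_VERSION from v3 to v4 changes the prefix under which registry objects are stored, so entries written by older builds are simply not found rather than misread. The resulting key shape, with an illustrative identifier:

    REGISTER_PREFIX = "distributions:registry"
    KEY_VERSION = "v4"
    KEY_FORMAT = f"{REGISTER_PREFIX}:{KEY_VERSION}::" + "{type}:{identifier}"

    key = KEY_FORMAT.format(type="model", identifier="my-model")
    # -> "distributions:registry:v4::model:my-model"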
@@ -22,6 +22,8 @@ async def get_provider_impl(
 deps[Api.memory],
 deps[Api.safety],
 deps[Api.memory_banks],
+deps[Api.tool_runtime],
+deps[Api.tool_groups],
 )
 await impl.initialize()
 return impl
@@ -4,8 +4,8 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-import asyncio
 import copy
+import json
 import logging
 import os
 import re

@@ -13,16 +13,16 @@ import secrets
 import string
 import uuid
 from datetime import datetime
-from typing import AsyncGenerator, Dict, List, Optional, Tuple
+from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple
 from urllib.parse import urlparse

 import httpx
-from llama_models.llama3.api.datatypes import BuiltinTool
+from llama_models.llama3.api.datatypes import BuiltinTool, ToolCall, ToolParamDefinition

 from llama_stack.apis.agents import (
 AgentConfig,
-AgentTool,
+AgentToolGroup,
+AgentToolGroupWithArgs,
 AgentTurnCreateRequest,
 AgentTurnResponseEvent,
 AgentTurnResponseEventType,

@@ -33,25 +33,14 @@ from llama_stack.apis.agents import (
 AgentTurnResponseTurnCompletePayload,
 AgentTurnResponseTurnStartPayload,
 Attachment,
-CodeInterpreterToolDefinition,
-FunctionCallToolDefinition,
+Document,
 InferenceStep,
-MemoryRetrievalStep,
-MemoryToolDefinition,
-PhotogenToolDefinition,
-SearchToolDefinition,
 ShieldCallStep,
 StepType,
 ToolExecutionStep,
 Turn,
-WolframAlphaToolDefinition,
-)
-
-from llama_stack.apis.common.content_types import (
-InterleavedContent,
-TextContentItem,
-URL,
 )
+from llama_stack.apis.common.content_types import TextContentItem, URL
 from llama_stack.apis.inference import (
 ChatCompletionResponseEventType,
 CompletionMessage,

@@ -62,32 +51,20 @@ from llama_stack.apis.inference import (
 SystemMessage,
 ToolCallDelta,
 ToolCallParseStatus,
-ToolChoice,
 ToolDefinition,
 ToolResponse,
 ToolResponseMessage,
 UserMessage,
 )
-from llama_stack.apis.memory import Memory, MemoryBankDocument, QueryDocumentsResponse
+from llama_stack.apis.memory import Memory, MemoryBankDocument
 from llama_stack.apis.memory_banks import MemoryBanks, VectorMemoryBankParams
 from llama_stack.apis.safety import Safety
+from llama_stack.apis.tools import ToolGroups, ToolRuntime
 from llama_stack.providers.utils.kvstore import KVStore
-from llama_stack.providers.utils.memory.vector_store import concat_interleaved_content
 from llama_stack.providers.utils.telemetry import tracing

 from .persistence import AgentPersistence
-from .rag.context_retriever import generate_rag_query
 from .safety import SafetyException, ShieldRunnerMixin
-from .tools.base import BaseTool
-from .tools.builtin import (
-CodeInterpreterTool,
-interpret_content_as_attachment,
-PhotogenTool,
-SearchTool,
-WolframAlphaTool,
-)
-from .tools.safety import SafeTool

 log = logging.getLogger(__name__)

@@ -98,6 +75,12 @@ def make_random_string(length: int = 8):
 )


+TOOLS_ATTACHMENT_KEY_REGEX = re.compile(r"__tools_attachment__=(\{.*?\})")
+MEMORY_QUERY_TOOL = "query_memory"
+WEB_SEARCH_TOOL = "web_search"
+MEMORY_GROUP = "builtin::memory"


 class ChatAgent(ShieldRunnerMixin):
 def __init__(
 self,

@@ -108,6 +91,8 @@ class ChatAgent(ShieldRunnerMixin):
 memory_api: Memory,
 memory_banks_api: MemoryBanks,
 safety_api: Safety,
+tool_runtime_api: ToolRuntime,
+tool_groups_api: ToolGroups,
 persistence_store: KVStore,
 ):
 self.agent_id = agent_id
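ChatAgent now receives the tool runtime and tool groups APIs, and built-in behaviour is selected by tool group name instead of per-tool definition classes. A sketch of the corresponding agent configuration; the model id is a placeholder and the exact set of AgentConfig fields is assumed from this diff rather than confirmed:

    from llama_stack.apis.agents import AgentConfig

    agent_config = AgentConfig(
        model="test_model",                    # placeholder model id
        instructions="You are a helpful assistant.",
        toolgroups=["builtin::memory", "builtin::code_interpreter"],
        client_tools=[],                       # tools executed on the client side
        enable_session_persistence=False,
    )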
@@ -118,29 +103,8 @@ class ChatAgent(ShieldRunnerMixin):
 self.memory_banks_api = memory_banks_api
 self.safety_api = safety_api
 self.storage = AgentPersistence(agent_id, persistence_store)
+self.tool_runtime_api = tool_runtime_api
+self.tool_groups_api = tool_groups_api
-builtin_tools = []
-for tool_defn in agent_config.tools:
-if isinstance(tool_defn, WolframAlphaToolDefinition):
-tool = WolframAlphaTool(tool_defn.api_key)
-elif isinstance(tool_defn, SearchToolDefinition):
-tool = SearchTool(tool_defn.engine, tool_defn.api_key)
-elif isinstance(tool_defn, CodeInterpreterToolDefinition):
-tool = CodeInterpreterTool()
-elif isinstance(tool_defn, PhotogenToolDefinition):
-tool = PhotogenTool(dump_dir=self.tempdir)
-else:
-continue
-
-builtin_tools.append(
-SafeTool(
-tool,
-safety_api,
-tool_defn.input_shields,
-tool_defn.output_shields,
-)
-)
-self.tools_dict = {t.get_name(): t for t in builtin_tools}
-

 ShieldRunnerMixin.__init__(
 self,
@@ -228,9 +192,10 @@ class ChatAgent(ShieldRunnerMixin):
 session_id=request.session_id,
 turn_id=turn_id,
 input_messages=messages,
-attachments=request.attachments or [],
 sampling_params=self.agent_config.sampling_params,
 stream=request.stream,
+documents=request.documents,
+toolgroups_for_turn=request.toolgroups,
 ):
 if isinstance(chunk, CompletionMessage):
 log.info(

@@ -278,9 +243,10 @@ class ChatAgent(ShieldRunnerMixin):
 session_id: str,
 turn_id: str,
 input_messages: List[Message],
-attachments: List[Attachment],
 sampling_params: SamplingParams,
 stream: bool = False,
+documents: Optional[List[Document]] = None,
+toolgroups_for_turn: Optional[List[AgentToolGroup]] = None,
 ) -> AsyncGenerator:
 # Doing async generators makes downstream code much simpler and everything amenable to
 # streaming. However, it also makes things complicated here because AsyncGenerators cannot
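Turn creation now threads per-turn documents and toolgroups through to _run in place of the removed attachments parameter. A sketch of a request that exercises it; the ids and the document URL are placeholders:

    from llama_stack.apis.agents import AgentTurnCreateRequest, Document
    from llama_stack.apis.inference import UserMessage

    request = AgentTurnCreateRequest(
        agent_id="agent-123",
        session_id="session-456",
        messages=[UserMessage(content="Summarize the attached notes.")],
        documents=[Document(content="https://example.com/notes.txt", mime_type="text/plain")],
        toolgroups=["builtin::memory"],
        stream=True,
    )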
@@ -297,7 +263,13 @@ class ChatAgent(ShieldRunnerMixin):
 yield res

 async for res in self._run(
-session_id, turn_id, input_messages, attachments, sampling_params, stream
+session_id,
+turn_id,
+input_messages,
+sampling_params,
+stream,
+documents,
+toolgroups_for_turn,
 ):
 if isinstance(res, bool):
 return

@@ -353,6 +325,7 @@ class ChatAgent(ShieldRunnerMixin):
 event=AgentTurnResponseEvent(
 payload=AgentTurnResponseStepCompletePayload(
 step_type=StepType.shield_call.value,
+step_id=step_id,
 step_details=ShieldCallStep(
 step_id=step_id,
 turn_id=turn_id,

@@ -373,6 +346,7 @@ class ChatAgent(ShieldRunnerMixin):
 event=AgentTurnResponseEvent(
 payload=AgentTurnResponseStepCompletePayload(
 step_type=StepType.shield_call.value,
+step_id=step_id,
 step_details=ShieldCallStep(
 step_id=step_id,
 turn_id=turn_id,
@@ -388,73 +362,116 @@ class ChatAgent(ShieldRunnerMixin):
 session_id: str,
 turn_id: str,
 input_messages: List[Message],
-attachments: List[Attachment],
 sampling_params: SamplingParams,
 stream: bool = False,
+documents: Optional[List[Document]] = None,
+toolgroups_for_turn: Optional[List[AgentToolGroup]] = None,
 ) -> AsyncGenerator:
-enabled_tools = set(t.type for t in self.agent_config.tools)
-need_rag_context = await self._should_retrieve_context(
-input_messages, attachments
-)
-if need_rag_context:
-step_id = str(uuid.uuid4())
-yield AgentTurnResponseStreamChunk(
-event=AgentTurnResponseEvent(
-payload=AgentTurnResponseStepStartPayload(
-step_type=StepType.memory_retrieval.value,
-step_id=step_id,
+toolgroup_args = {}
+for toolgroup in self.agent_config.toolgroups:
+if isinstance(toolgroup, AgentToolGroupWithArgs):
+toolgroup_args[toolgroup.name] = toolgroup.args
+if toolgroups_for_turn:
+for toolgroup in toolgroups_for_turn:
+if isinstance(toolgroup, AgentToolGroupWithArgs):
+toolgroup_args[toolgroup.name] = toolgroup.args
+
+tool_defs, tool_to_group = await self._get_tool_defs(toolgroups_for_turn)
+if documents:
+await self.handle_documents(
+session_id, documents, input_messages, tool_defs
+)
+if MEMORY_QUERY_TOOL in tool_defs and len(input_messages) > 0:
+memory_tool_group = tool_to_group.get(MEMORY_QUERY_TOOL, None)
+if memory_tool_group is None:
+raise ValueError(f"Memory tool group not found for {MEMORY_QUERY_TOOL}")
+with tracing.span(MEMORY_QUERY_TOOL) as span:
+step_id = str(uuid.uuid4())
+yield AgentTurnResponseStreamChunk(
+event=AgentTurnResponseEvent(
+payload=AgentTurnResponseStepStartPayload(
+step_type=StepType.tool_execution.value,
+step_id=step_id,
+)
 )
 )
-)
+query_args = {
+"messages": [msg.content for msg in input_messages],
+**toolgroup_args.get(memory_tool_group, {}),
+}

-# TODO: find older context from the session and either replace it
-# or append with a sliding window. this is really a very simplistic implementation
-with tracing.span("retrieve_rag_context") as span:
-rag_context, bank_ids = await self._retrieve_context(
-session_id, input_messages, attachments
+session_info = await self.storage.get_session_info(session_id)
+# if the session has a memory bank id, let the memory tool use it
+if session_info.memory_bank_id:
+if "memory_bank_ids" not in query_args:
+query_args["memory_bank_ids"] = []
+query_args["memory_bank_ids"].append(session_info.memory_bank_id)
+yield AgentTurnResponseStreamChunk(
+event=AgentTurnResponseEvent(
+payload=AgentTurnResponseStepProgressPayload(
+step_type=StepType.tool_execution.value,
+step_id=step_id,
+tool_call_delta=ToolCallDelta(
+parse_status=ToolCallParseStatus.success,
+content=ToolCall(
+call_id="",
+tool_name=MEMORY_QUERY_TOOL,
+arguments={},
+),
+),
+)
+)
+)
+result = await self.tool_runtime_api.invoke_tool(
+tool_name=MEMORY_QUERY_TOOL,
+args=query_args,
+)
+
+yield AgentTurnResponseStreamChunk(
+event=AgentTurnResponseEvent(
+payload=AgentTurnResponseStepCompletePayload(
+step_type=StepType.tool_execution.value,
+step_id=step_id,
+step_details=ToolExecutionStep(
+step_id=step_id,
+turn_id=turn_id,
+tool_calls=[
+ToolCall(
+call_id="",
+tool_name=MEMORY_QUERY_TOOL,
+arguments={},
+)
+],
+tool_responses=[
+ToolResponse(
+call_id="",
+tool_name=MEMORY_QUERY_TOOL,
+content=result.content,
+)
+],
+),
+)
+)
 )
 span.set_attribute(
 "input", [m.model_dump_json() for m in input_messages]
 )
-span.set_attribute("output", rag_context)
-span.set_attribute("bank_ids", bank_ids)
-
-step_id = str(uuid.uuid4())
-yield AgentTurnResponseStreamChunk(
-event=AgentTurnResponseEvent(
-payload=AgentTurnResponseStepCompletePayload(
-step_type=StepType.memory_retrieval.value,
-step_id=step_id,
-step_details=MemoryRetrievalStep(
-turn_id=turn_id,
-step_id=step_id,
-memory_bank_ids=bank_ids,
-inserted_context=rag_context or "",
-),
-)
-)
-)
-
-if rag_context:
-last_message = input_messages[-1]
-last_message.context = rag_context
-
-elif attachments and AgentTool.code_interpreter.value in enabled_tools:
-urls = [a.content for a in attachments if isinstance(a.content, URL)]
-# TODO: we need to migrate URL away from str type
-pattern = re.compile("^(https?://|file://|data:)")
-urls += [
-URL(uri=a.content) for a in attachments if pattern.match(a.content)
-]
-msg = await attachment_message(self.tempdir, urls)
-input_messages.append(msg)
+span.set_attribute("output", result.content)
+span.set_attribute("error_code", result.error_code)
+span.set_attribute("error_message", result.error_message)
+span.set_attribute("tool_name", MEMORY_QUERY_TOOL)
+if result.error_code == 0:
+last_message = input_messages[-1]
+last_message.context = result.content

 output_attachments = []

 n_iter = 0
+# Build a map of custom tools to their definitions for faster lookup
+client_tools = {}
+for tool in self.agent_config.client_tools:
+client_tools[tool.name] = tool
 while True:
-msg = input_messages[-1]

 step_id = str(uuid.uuid4())
 yield AgentTurnResponseStreamChunk(
 event=AgentTurnResponseEvent(
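Instead of the removed _retrieve_context path, the memory toolgroup is now queried through the tool runtime before inference. A condensed sketch of that call, assuming the builtin::memory group registered a query_memory tool; the bank id is illustrative:

    query_args = {
        "messages": [msg.content for msg in input_messages],
        "memory_bank_ids": ["session-memory-bank"],  # illustrative id
    }
    result = await self.tool_runtime_api.invoke_tool(
        tool_name="query_memory",
        args=query_args,
    )
    if result.error_code == 0:
        # retrieved context rides along on the last user message
        input_messages[-1].context = result.content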
@@ -473,7 +490,11 @@ class ChatAgent(ShieldRunnerMixin):
 async for chunk in await self.inference_api.chat_completion(
 self.agent_config.model,
 input_messages,
-tools=self._get_tools(),
+tools=[
+tool
+for tool in tool_defs.values()
+if tool_to_group.get(tool.tool_name, None) != MEMORY_GROUP
+],
 tool_prompt_format=self.agent_config.tool_prompt_format,
 stream=True,
 sampling_params=sampling_params,

@@ -572,9 +593,9 @@ class ChatAgent(ShieldRunnerMixin):
 # TODO: UPDATE RETURN TYPE TO SEND A TUPLE OF (MESSAGE, ATTACHMENTS)
 if len(output_attachments) > 0:
 if isinstance(message.content, list):
-message.content += attachments
+message.content += output_attachments
 else:
-message.content = [message.content] + attachments
+message.content = [message.content] + output_attachments
 yield message
 else:
 log.info(f"Partial message: {str(message)}")

@@ -582,9 +603,7 @@ class ChatAgent(ShieldRunnerMixin):
 else:
 log.info(f"{str(message)}")
 tool_call = message.tool_calls[0]
-name = tool_call.tool_name
-if not isinstance(name, BuiltinTool) or name not in enabled_tools:
+if tool_call.tool_name in client_tools:
 yield message
 return

@@ -607,16 +626,22 @@ class ChatAgent(ShieldRunnerMixin):
 )
 )

+tool_name = tool_call.tool_name
+if isinstance(tool_name, BuiltinTool):
+tool_name = tool_name.value
 with tracing.span(
 "tool_execution",
 {
-"tool_name": tool_call.tool_name,
+"tool_name": tool_name,
 "input": message.model_dump_json(),
 },
 ) as span:
 result_messages = await execute_tool_call_maybe(
-self.tools_dict,
+self.tool_runtime_api,
+session_id,
 [message],
+toolgroup_args,
+tool_to_group,
 )
 assert (
 len(result_messages) == 1

@@ -628,6 +653,7 @@ class ChatAgent(ShieldRunnerMixin):
 event=AgentTurnResponseEvent(
 payload=AgentTurnResponseStepCompletePayload(
 step_type=StepType.tool_execution.value,
+step_id=step_id,
 step_details=ToolExecutionStep(
 step_id=step_id,
 turn_id=turn_id,

@@ -647,7 +673,7 @@ class ChatAgent(ShieldRunnerMixin):
 # TODO: add tool-input touchpoint and a "start" event for this step also
 # but that needs a lot more refactoring of Tool code potentially

-if out_attachment := interpret_content_as_attachment(
+if out_attachment := _interpret_content_as_attachment(
 result_message.content
 ):
 # NOTE: when we push this message back to the model, the model may ignore the
@@ -659,6 +685,150 @@ class ChatAgent(ShieldRunnerMixin):

 n_iter += 1

+async def _get_tool_defs(
+self, toolgroups_for_turn: Optional[List[AgentToolGroup]] = None
+) -> Tuple[Dict[str, ToolDefinition], Dict[str, str]]:
+# Determine which tools to include
+agent_config_toolgroups = set(
+(
+toolgroup.name
+if isinstance(toolgroup, AgentToolGroupWithArgs)
+else toolgroup
+)
+for toolgroup in self.agent_config.toolgroups
+)
+toolgroups_for_turn_set = (
+agent_config_toolgroups
+if toolgroups_for_turn is None
+else {
+(
+toolgroup.name
+if isinstance(toolgroup, AgentToolGroupWithArgs)
+else toolgroup
+)
+for toolgroup in toolgroups_for_turn
+}
+)
+
+tool_def_map = {}
+tool_to_group = {}
+
+for tool_def in self.agent_config.client_tools:
+if tool_def_map.get(tool_def.name, None):
+raise ValueError(f"Tool {tool_def.name} already exists")
+tool_def_map[tool_def.name] = ToolDefinition(
+tool_name=tool_def.name,
+description=tool_def.description,
+parameters={
+param.name: ToolParamDefinition(
+param_type=param.parameter_type,
+description=param.description,
+required=param.required,
+default=param.default,
+)
+for param in tool_def.parameters
+},
+)
+tool_to_group[tool_def.name] = "__client_tools__"
+for toolgroup_name in agent_config_toolgroups:
+if toolgroup_name not in toolgroups_for_turn_set:
+continue
+tools = await self.tool_groups_api.list_tools(tool_group_id=toolgroup_name)
+for tool_def in tools:
+if (
+toolgroup_name.startswith("builtin")
+and toolgroup_name != MEMORY_GROUP
+):
+tool_name = tool_def.identifier
+built_in_type = BuiltinTool.brave_search
+if tool_name == "web_search":
+built_in_type = BuiltinTool.brave_search
+else:
+built_in_type = BuiltinTool(tool_name)
+
+if tool_def_map.get(built_in_type, None):
+raise ValueError(f"Tool {built_in_type} already exists")
+
+tool_def_map[built_in_type] = ToolDefinition(
+tool_name=built_in_type
+)
+tool_to_group[built_in_type] = tool_def.toolgroup_id
+continue
+
+if tool_def_map.get(tool_def.identifier, None):
+raise ValueError(f"Tool {tool_def.identifier} already exists")
+tool_def_map[tool_def.identifier] = ToolDefinition(
+tool_name=tool_def.identifier,
+description=tool_def.description,
+parameters={
+param.name: ToolParamDefinition(
+param_type=param.parameter_type,
+description=param.description,
+required=param.required,
+default=param.default,
+)
+for param in tool_def.parameters
+},
+)
+tool_to_group[tool_def.identifier] = tool_def.toolgroup_id
+
+return tool_def_map, tool_to_group
+
+async def handle_documents(
+self,
+session_id: str,
+documents: List[Document],
+input_messages: List[Message],
+tool_defs: Dict[str, ToolDefinition],
+) -> None:
+memory_tool = tool_defs.get(MEMORY_QUERY_TOOL, None)
+code_interpreter_tool = tool_defs.get(BuiltinTool.code_interpreter, None)
+content_items = []
+url_items = []
+pattern = re.compile("^(https?://|file://|data:)")
+for d in documents:
+if isinstance(d.content, URL):
+url_items.append(d.content)
+elif pattern.match(d.content):
+url_items.append(URL(uri=d.content))
+else:
+content_items.append(d)
+
+# Save the contents to a tempdir and use its path as a URL if code interpreter is present
+if code_interpreter_tool:
+for c in content_items:
+temp_file_path = os.path.join(
+self.tempdir, f"{make_random_string()}.txt"
+)
+with open(temp_file_path, "w") as temp_file:
+temp_file.write(c.content)
+url_items.append(URL(uri=f"file://{temp_file_path}"))
+
+if memory_tool and code_interpreter_tool:
+# if both memory and code_interpreter are available, we download the URLs
+# and attach the data to the last message.
+msg = await attachment_message(self.tempdir, url_items)
+input_messages.append(msg)
+# Since memory is present, add all the data to the memory bank
+await self.add_to_session_memory_bank(session_id, documents)
+elif code_interpreter_tool:
+# if only code_interpreter is available, we download the URLs to a tempdir
+# and attach the path to them as a message to inference with the
+# assumption that the model invokes the code_interpreter tool with the path
+msg = await attachment_message(self.tempdir, url_items)
+input_messages.append(msg)
+elif memory_tool:
+# if only memory is available, we load the data from the URLs and content items to the memory bank
+await self.add_to_session_memory_bank(session_id, documents)
+else:
+# if no memory or code_interpreter tool is available,
+# we try to load the data from the URLs and content items as a message to inference
+# and add it to the last message's context
+input_messages[-1].context = "\n".join(
+[doc.content for doc in content_items]
++ await load_data_from_urls(url_items)
+)
+
 async def _ensure_memory_bank(self, session_id: str) -> str:
 session_info = await self.storage.get_session_info(session_id)
 if session_info is None:
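_get_tool_defs collapses the agent's client tools and toolgroups into two maps that the rest of the turn logic keys off. A sketch of the shapes it returns for an agent with one client tool plus the memory and code interpreter groups; the client tool name is illustrative:

    tool_defs, tool_to_group = await agent._get_tool_defs()
    # tool_defs maps a tool key to its ToolDefinition, roughly:
    #   {"my_client_tool": ToolDefinition(...),
    #    "query_memory": ToolDefinition(...),
    #    BuiltinTool.code_interpreter: ToolDefinition(...)}
    # tool_to_group records where each tool came from, roughly:
    #   {"my_client_tool": "__client_tools__",
    #    "query_memory": "builtin::memory",
    #    BuiltinTool.code_interpreter: "builtin::code_interpreter"}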
@@ -679,129 +849,39 @@ class ChatAgent(ShieldRunnerMixin):

 return bank_id

-async def _should_retrieve_context(
-self, messages: List[Message], attachments: List[Attachment]
-) -> bool:
-enabled_tools = set(t.type for t in self.agent_config.tools)
-if attachments:
-if (
-AgentTool.code_interpreter.value in enabled_tools
-and self.agent_config.tool_choice == ToolChoice.required
-):
-return False
-else:
-return True
-
-return AgentTool.memory.value in enabled_tools
-
-def _memory_tool_definition(self) -> Optional[MemoryToolDefinition]:
-for t in self.agent_config.tools:
-if t.type == AgentTool.memory.value:
-return t
-
-return None
-
-async def _retrieve_context(
-self, session_id: str, messages: List[Message], attachments: List[Attachment]
-) -> Tuple[Optional[InterleavedContent], List[int]]:  # (rag_context, bank_ids)
-bank_ids = []
-
-memory = self._memory_tool_definition()
-assert memory is not None, "Memory tool not configured"
-bank_ids.extend(c.bank_id for c in memory.memory_bank_configs)
-
-if attachments:
-bank_id = await self._ensure_memory_bank(session_id)
-bank_ids.append(bank_id)
-
-documents = [
-MemoryBankDocument(
-document_id=str(uuid.uuid4()),
-content=a.content,
-mime_type=a.mime_type,
-metadata={},
-)
-for a in attachments
-]
-with tracing.span("insert_documents"):
-await self.memory_api.insert_documents(bank_id, documents)
-else:
-session_info = await self.storage.get_session_info(session_id)
-if session_info.memory_bank_id:
-bank_ids.append(session_info.memory_bank_id)
-
-if not bank_ids:
-# this can happen if the per-session memory bank is not yet populated
-# (i.e., no prior turns uploaded an Attachment)
-return None, []
-
-query = await generate_rag_query(
-memory.query_generator_config, messages, inference_api=self.inference_api
-)
-tasks = [
-self.memory_api.query_documents(
-bank_id=bank_id,
-query=query,
-params={
-"max_chunks": 5,
-},
-)
-for bank_id in bank_ids
-]
-results: List[QueryDocumentsResponse] = await asyncio.gather(*tasks)
-chunks = [c for r in results for c in r.chunks]
-scores = [s for r in results for s in r.scores]
-
-if not chunks:
-return None, bank_ids
-
-# sort by score
-chunks, scores = zip(
-*sorted(zip(chunks, scores), key=lambda x: x[1], reverse=True)
-)
-
-tokens = 0
-picked = []
-for c in chunks[: memory.max_chunks]:
-tokens += c.token_count
-if tokens > memory.max_tokens_in_context:
-log.error(
-f"Using {len(picked)} chunks; reached max tokens in context: {tokens}",
-)
-break
-picked.append(f"id:{c.document_id}; content:{c.content}")
-
-return (
-concat_interleaved_content(
-[
-"Here are the retrieved documents for relevant context:\n=== START-RETRIEVED-CONTEXT ===\n",
-*picked,
-"\n=== END-RETRIEVED-CONTEXT ===\n",
-]
-),
-bank_ids,
-)
-
-def _get_tools(self) -> List[ToolDefinition]:
-ret = []
-for t in self.agent_config.tools:
-if isinstance(t, SearchToolDefinition):
-ret.append(ToolDefinition(tool_name=BuiltinTool.brave_search))
-elif isinstance(t, WolframAlphaToolDefinition):
-ret.append(ToolDefinition(tool_name=BuiltinTool.wolfram_alpha))
-elif isinstance(t, PhotogenToolDefinition):
-ret.append(ToolDefinition(tool_name=BuiltinTool.photogen))
-elif isinstance(t, CodeInterpreterToolDefinition):
-ret.append(ToolDefinition(tool_name=BuiltinTool.code_interpreter))
-elif isinstance(t, FunctionCallToolDefinition):
-ret.append(
-ToolDefinition(
-tool_name=t.function_name,
-description=t.description,
-parameters=t.parameters,
-)
-)
-return ret
+async def add_to_session_memory_bank(
+self, session_id: str, data: List[Document]
+) -> None:
+bank_id = await self._ensure_memory_bank(session_id)
+documents = [
+MemoryBankDocument(
+document_id=str(uuid.uuid4()),
+content=a.content,
+mime_type=a.mime_type,
+metadata={},
+)
+for a in data
+]
+await self.memory_api.insert_documents(
+bank_id=bank_id,
+documents=documents,
+)
+
+
+async def load_data_from_urls(urls: List[URL]) -> List[str]:
+data = []
+for url in urls:
+uri = url.uri
+if uri.startswith("file://"):
+filepath = uri[len("file://") :]
+with open(filepath, "r") as f:
+data.append(f.read())
+elif uri.startswith("http"):
+async with httpx.AsyncClient() as client:
+r = await client.get(uri)
+resp = r.text
+data.append(resp)
+return data

 async def attachment_message(tempdir: str, urls: List[URL]) -> ToolResponseMessage:
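load_data_from_urls gives handle_documents a single way to inline document bodies whether they live on disk or behind HTTP. A short usage sketch with placeholder locations:

    from llama_stack.apis.common.content_types import URL

    texts = await load_data_from_urls(
        [
            URL(uri="file:///tmp/notes.txt"),          # read from local disk
            URL(uri="https://example.com/notes.txt"),  # fetched with httpx
        ]
    )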
@@ -839,7 +919,11 @@ async def attachment_message(tempdir: str, urls: List[URL]) -> ToolResponseMessage:


 async def execute_tool_call_maybe(
-tools_dict: Dict[str, BaseTool], messages: List[CompletionMessage]
+tool_runtime_api: ToolRuntime,
+session_id: str,
+messages: List[CompletionMessage],
+toolgroup_args: Dict[str, Dict[str, Any]],
+tool_to_group: Dict[str, str],
 ) -> List[ToolResponseMessage]:
 # While Tools.run interface takes a list of messages,
 # All tools currently only run on a single message

@@ -851,11 +935,45 @@ async def execute_tool_call_maybe(

 tool_call = message.tool_calls[0]
 name = tool_call.tool_name
-assert isinstance(name, BuiltinTool)
-
-name = name.value
-
-assert name in tools_dict, f"Tool {name} not found"
-tool = tools_dict[name]
-result_messages = await tool.run(messages)
-return result_messages
+group_name = tool_to_group.get(name, None)
+if group_name is None:
+raise ValueError(f"Tool {name} not found in any tool group")
+# get the arguments generated by the model and augment with toolgroup arg overrides for the agent
+tool_call_args = tool_call.arguments
+tool_call_args.update(toolgroup_args.get(group_name, {}))
+if isinstance(name, BuiltinTool):
+if name == BuiltinTool.brave_search:
+name = WEB_SEARCH_TOOL
+else:
+name = name.value
+
+result = await tool_runtime_api.invoke_tool(
+tool_name=name,
+args=dict(
+session_id=session_id,
+**tool_call_args,
+),
+)
+
+return [
+ToolResponseMessage(
+call_id=tool_call.call_id,
+tool_name=tool_call.tool_name,
+content=result.content,
+)
+]
+
+
+def _interpret_content_as_attachment(
+content: str,
+) -> Optional[Attachment]:
+match = re.search(TOOLS_ATTACHMENT_KEY_REGEX, content)
+if match:
+snippet = match.group(1)
+data = json.loads(snippet)
+return Attachment(
+url=URL(uri="file://" + data["filepath"]),
+mime_type=data["mimetype"],
+)
+
+return None
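Tool output can still hand a file back to the agent through the __tools_attachment__ marker that _interpret_content_as_attachment parses. A sketch of the expected marker, with an illustrative path:

    content = 'done __tools_attachment__={"filepath": "/tmp/plot.png", "mimetype": "image/png"}'
    attachment = _interpret_content_as_attachment(content)
    # -> Attachment(url=URL(uri="file:///tmp/plot.png"), mime_type="image/png")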
@@ -19,17 +19,17 @@ from llama_stack.apis.agents import (
 Agents,
 AgentSessionCreateResponse,
 AgentStepResponse,
+AgentToolGroup,
 AgentTurnCreateRequest,
-Attachment,
+Document,
 Session,
 Turn,
 )

 from llama_stack.apis.inference import Inference, ToolResponseMessage, UserMessage
 from llama_stack.apis.memory import Memory
 from llama_stack.apis.memory_banks import MemoryBanks
 from llama_stack.apis.safety import Safety
+from llama_stack.apis.tools import ToolGroups, ToolRuntime
 from llama_stack.providers.utils.kvstore import InmemoryKVStoreImpl, kvstore_impl

 from .agent_instance import ChatAgent

@@ -47,12 +47,16 @@ class MetaReferenceAgentsImpl(Agents):
 memory_api: Memory,
 safety_api: Safety,
 memory_banks_api: MemoryBanks,
+tool_runtime_api: ToolRuntime,
+tool_groups_api: ToolGroups,
 ):
 self.config = config
 self.inference_api = inference_api
 self.memory_api = memory_api
 self.safety_api = safety_api
 self.memory_banks_api = memory_banks_api
+self.tool_runtime_api = tool_runtime_api
+self.tool_groups_api = tool_groups_api

 self.in_memory_store = InmemoryKVStoreImpl()
 self.tempdir = tempfile.mkdtemp()

@@ -112,6 +116,8 @@ class MetaReferenceAgentsImpl(Agents):
 safety_api=self.safety_api,
 memory_api=self.memory_api,
 memory_banks_api=self.memory_banks_api,
+tool_runtime_api=self.tool_runtime_api,
+tool_groups_api=self.tool_groups_api,
 persistence_store=(
 self.persistence_store
 if agent_config.enable_session_persistence

@@ -141,15 +147,17 @@ class MetaReferenceAgentsImpl(Agents):
 ToolResponseMessage,
 ]
 ],
-attachments: Optional[List[Attachment]] = None,
+toolgroups: Optional[List[AgentToolGroup]] = None,
+documents: Optional[List[Document]] = None,
 stream: Optional[bool] = False,
 ) -> AsyncGenerator:
 request = AgentTurnCreateRequest(
 agent_id=agent_id,
 session_id=session_id,
 messages=messages,
-attachments=attachments,
 stream=True,
+toolgroups=toolgroups,
+documents=documents,
 )
 if stream:
 return self._create_agent_turn_streaming(request)
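Callers of the agents API pass toolgroups and documents per turn now that attachments are gone. A sketch of the updated call, assuming the Agents.create_agent_turn entry point and placeholder identifiers:

    response = await agents_api.create_agent_turn(
        agent_id="agent-123",
        session_id="session-456",
        messages=[UserMessage(content="What changed in the attached file?")],
        documents=[Document(content="file:///tmp/change.log", mime_type="text/plain")],
        toolgroups=["builtin::code_interpreter"],
        stream=True,
    )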
@@ -8,13 +8,11 @@ import json
 import logging
 import uuid
 from datetime import datetime
-
 from typing import List, Optional
-
 from pydantic import BaseModel

 from llama_stack.apis.agents import Turn

 from llama_stack.providers.utils.kvstore import KVStore

 log = logging.getLogger(__name__)
@@ -1,93 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the terms described in the LICENSE file in
-# the root directory of this source tree.
-
-import unittest
-
-from llama_models.llama3.api.datatypes import (
-Attachment,
-BuiltinTool,
-CompletionMessage,
-StopReason,
-ToolCall,
-)
-
-from ..tools.builtin import CodeInterpreterTool
-
-
-class TestCodeInterpreter(unittest.IsolatedAsyncioTestCase):
-async def test_matplotlib(self):
-tool = CodeInterpreterTool()
-code = """
-import matplotlib.pyplot as plt
-import numpy as np
-
-x = np.array([1, 1])
-y = np.array([0, 10])
-
-plt.plot(x, y)
-plt.title('x = 1')
-plt.xlabel('x')
-plt.ylabel('y')
-plt.grid(True)
-plt.axvline(x=1, color='r')
-plt.show()
-"""
-message = CompletionMessage(
-role="assistant",
-content="",
-tool_calls=[
-ToolCall(
-call_id="call_id",
-tool_name=BuiltinTool.code_interpreter,
-arguments={"code": code},
-)
-],
-stop_reason=StopReason.end_of_message,
-)
-ret = await tool.run([message])
-
-self.assertEqual(len(ret), 1)
-
-output = ret[0].content
-self.assertIsInstance(output, Attachment)
-self.assertEqual(output.mime_type, "image/png")
-
-async def test_path_unlink(self):
-tool = CodeInterpreterTool()
-code = """
-import os
-from pathlib import Path
-import tempfile
-
-dpath = Path(os.environ["MPLCONFIGDIR"])
-with open(dpath / "test", "w") as f:
-f.write("hello")
-
-Path(dpath / "test").unlink()
-print("_OK_")
-"""
-message = CompletionMessage(
-role="assistant",
-content="",
-tool_calls=[
-ToolCall(
-call_id="call_id",
-tool_name=BuiltinTool.code_interpreter,
-arguments={"code": code},
-)
-],
-stop_reason=StopReason.end_of_message,
-)
-ret = await tool.run([message])
-
-self.assertEqual(len(ret), 1)
-
-output = ret[0].content
-self.assertTrue("_OK_" in output)
-
-
-if __name__ == "__main__":
-unittest.main()
@@ -4,21 +4,26 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

+import tempfile
 from typing import AsyncIterator, List, Optional, Union

 import pytest
+from llama_models.llama3.api.datatypes import BuiltinTool

 from llama_stack.apis.agents import (
 AgentConfig,
+AgentToolGroupWithArgs,
 AgentTurnCreateRequest,
 AgentTurnResponseTurnCompletePayload,
+StepType,
 )
+from llama_stack.apis.common.content_types import URL
 from llama_stack.apis.inference import (
 ChatCompletionResponse,
 ChatCompletionResponseEvent,
 ChatCompletionResponseStreamChunk,
 CompletionMessage,
+LogProbConfig,
 Message,
 ResponseFormat,
 SamplingParams,

@@ -27,13 +32,24 @@ from llama_stack.apis.inference import (
 UserMessage,
 )
 from llama_stack.apis.memory import MemoryBank
+from llama_stack.apis.memory_banks import BankParams, VectorMemoryBank
 from llama_stack.apis.safety import RunShieldResponse
-from ..agents import (
-AGENT_INSTANCES_BY_ID,
-MetaReferenceAgentsImpl,
-MetaReferenceInferenceConfig,
+from llama_stack.apis.tools import (
+Tool,
+ToolDef,
+ToolGroup,
+ToolHost,
+ToolInvocationResult,
+ToolPromptFormat,
 )
+from llama_stack.providers.inline.agents.meta_reference.agent_instance import (
+MEMORY_QUERY_TOOL,
+)
+from llama_stack.providers.inline.agents.meta_reference.agents import (
+MetaReferenceAgentsImpl,
+MetaReferenceAgentsImplConfig,
+)
+from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig


 class MockInferenceAPI:

@@ -48,10 +64,10 @@ class MockInferenceAPI:
 tool_prompt_format: Optional[ToolPromptFormat] = None,
 stream: Optional[bool] = False,
 logprobs: Optional[LogProbConfig] = None,
-) -> AsyncIterator[
-Union[ChatCompletionResponseStreamChunk, ChatCompletionResponse]
+) -> Union[
+ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]
 ]:
-if stream:
+async def stream_response():
 yield ChatCompletionResponseStreamChunk(
 event=ChatCompletionResponseEvent(
 event_type="start",

@@ -65,19 +81,7 @@ class MockInferenceAPI:
 delta="AI is a fascinating field...",
 )
 )
-# yield ChatCompletionResponseStreamChunk(
-# event=ChatCompletionResponseEvent(
-# event_type="progress",
-# delta=ToolCallDelta(
-# content=ToolCall(
-# call_id="123",
-# tool_name=BuiltinTool.brave_search.value,
-# arguments={"query": "AI history"},
-# ),
-# parse_status="success",
-# ),
-# )
-# )
 yield ChatCompletionResponseStreamChunk(
 event=ChatCompletionResponseEvent(
 event_type="complete",

@@ -85,12 +89,17 @@ class MockInferenceAPI:
 stop_reason="end_of_turn",
 )
 )

+if stream:
+return stream_response()
 else:
-yield ChatCompletionResponse(
+return ChatCompletionResponse(
 completion_message=CompletionMessage(
-role="assistant", content="Mock response", stop_reason="end_of_turn"
+role="assistant",
+content="Mock response",
+stop_reason="end_of_turn",
 ),
-logprobs=[0.1, 0.2, 0.3] if logprobs else None,
+logprobs={"token_logprobs": [0.1, 0.2, 0.3]} if logprobs else None,
 )

@@ -165,6 +174,98 @@ class MockMemoryAPI:
 self.documents[bank_id].pop(doc_id, None)


+class MockToolGroupsAPI:
+async def register_tool_group(
+self, toolgroup_id: str, provider_id: str, mcp_endpoint=None, args=None
+) -> None:
+pass
+
+async def get_tool_group(self, toolgroup_id: str) -> ToolGroup:
+return ToolGroup(
+identifier=toolgroup_id,
+provider_resource_id=toolgroup_id,
+)
+
+async def list_tool_groups(self) -> List[ToolGroup]:
+return []
+
+async def list_tools(self, tool_group_id: Optional[str] = None) -> List[Tool]:
+if tool_group_id == MEMORY_TOOLGROUP:
+return [
+Tool(
+identifier=MEMORY_QUERY_TOOL,
+provider_resource_id=MEMORY_QUERY_TOOL,
+toolgroup_id=MEMORY_TOOLGROUP,
+tool_host=ToolHost.client,
+description="Mock tool",
+provider_id="builtin::memory",
+parameters=[],
+)
+]
+if tool_group_id == CODE_INTERPRETER_TOOLGROUP:
+return [
+Tool(
+identifier="code_interpreter",
+provider_resource_id="code_interpreter",
+toolgroup_id=CODE_INTERPRETER_TOOLGROUP,
+tool_host=ToolHost.client,
+description="Mock tool",
+provider_id="builtin::code_interpreter",
+parameters=[],
+)
+]
+return []
+
+async def get_tool(self, tool_name: str) -> Tool:
+return Tool(
+identifier=tool_name,
+provider_resource_id=tool_name,
+toolgroup_id="mock_group",
+tool_host=ToolHost.client,
+description="Mock tool",
+provider_id="mock_provider",
+parameters=[],
+)
+
+async def unregister_tool_group(self, tool_group_id: str) -> None:
+pass
+
+
+class MockToolRuntimeAPI:
+async def list_runtime_tools(
+self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
+) -> List[ToolDef]:
+return []
+
+async def invoke_tool(self, tool_name: str, args: dict) -> ToolInvocationResult:
+return ToolInvocationResult(content={"result": "Mock tool result"})
+
+
+class MockMemoryBanksAPI:
+async def list_memory_banks(self) -> List[MemoryBank]:
+return []
+
+async def get_memory_bank(self, memory_bank_id: str) -> Optional[MemoryBank]:
+return None
+
+async def register_memory_bank(
+self,
+memory_bank_id: str,
+params: BankParams,
+provider_id: Optional[str] = None,
+provider_memory_bank_id: Optional[str] = None,
+) -> MemoryBank:
+return VectorMemoryBank(
+identifier=memory_bank_id,
+provider_resource_id=provider_memory_bank_id or memory_bank_id,
+embedding_model="mock_model",
+chunk_size_in_tokens=512,
+)
+
+async def unregister_memory_bank(self, memory_bank_id: str) -> None:
+pass
+
+
 @pytest.fixture
 def mock_inference_api():
 return MockInferenceAPI()
@@ -181,64 +282,107 @@ def mock_memory_api():


 @pytest.fixture
-async def chat_agent(mock_inference_api, mock_safety_api, mock_memory_api):
+def mock_tool_groups_api():
+return MockToolGroupsAPI()
+
+
+@pytest.fixture
+def mock_tool_runtime_api():
+return MockToolRuntimeAPI()
+
+
+@pytest.fixture
+def mock_memory_banks_api():
+return MockMemoryBanksAPI()
+
+
+@pytest.fixture
+async def get_agents_impl(
+mock_inference_api,
+mock_safety_api,
+mock_memory_api,
+mock_memory_banks_api,
+mock_tool_runtime_api,
+mock_tool_groups_api,
+):
+sqlite_file = tempfile.NamedTemporaryFile(delete=False, suffix=".db")
 impl = MetaReferenceAgentsImpl(
-config=MetaReferenceInferenceConfig(),
+config=MetaReferenceAgentsImplConfig(
+persistence_store=SqliteKVStoreConfig(
+db_name=sqlite_file.name,
+),
+),
 inference_api=mock_inference_api,
 safety_api=mock_safety_api,
 memory_api=mock_memory_api,
+memory_banks_api=mock_memory_banks_api,
+tool_runtime_api=mock_tool_runtime_api,
+tool_groups_api=mock_tool_groups_api,
 )
 await impl.initialize()
+return impl
+
+
+@pytest.fixture
+async def get_chat_agent(get_agents_impl):
+impl = await get_agents_impl
 agent_config = AgentConfig(
 model="test_model",
 instructions="You are a helpful assistant.",
-sampling_params=SamplingParams(),
-tools=[
-# SearchToolDefinition(
-# name="brave_search",
-# api_key="test_key",
-# ),
-],
+toolgroups=[],
 tool_choice=ToolChoice.auto,
 enable_session_persistence=False,
-input_shields=[],
-output_shields=[],
+input_shields=["test_shield"],
 )
 response = await impl.create_agent(agent_config)
-agent = AGENT_INSTANCES_BY_ID[response.agent_id]
-return agent
+return await impl.get_agent(response.agent_id)
+
+
+MEMORY_TOOLGROUP = "builtin::memory"
+CODE_INTERPRETER_TOOLGROUP = "builtin::code_interpreter"
+
+
+@pytest.fixture
+async def get_chat_agent_with_tools(get_agents_impl, request):
+impl = await get_agents_impl
+toolgroups = request.param
+agent_config = AgentConfig(
+model="test_model",
+instructions="You are a helpful assistant.",
+toolgroups=toolgroups,
+tool_choice=ToolChoice.auto,
+enable_session_persistence=False,
+input_shields=["test_shield"],
+)
+response = await impl.create_agent(agent_config)
+return await impl.get_agent(response.agent_id)
+
+
 @pytest.mark.asyncio
-async def test_chat_agent_create_session(chat_agent):
-session = chat_agent.create_session("Test Session")
-assert session.session_name == "Test Session"
-assert session.turns == []
-assert session.session_id in chat_agent.sessions
-
-
-@pytest.mark.asyncio
-async def test_chat_agent_create_and_execute_turn(chat_agent):
-session = chat_agent.create_session("Test Session")
+async def test_chat_agent_create_and_execute_turn(get_chat_agent):
+chat_agent = await get_chat_agent
+session_id = await chat_agent.create_session("Test Session")
 request = AgentTurnCreateRequest(
-agent_id="random",
-session_id=session.session_id,
+agent_id=chat_agent.agent_id,
+session_id=session_id,
 messages=[UserMessage(content="Hello")],
+stream=True,
 )

 responses = []
 async for response in chat_agent.create_and_execute_turn(request):
 responses.append(response)

-print(responses)
 assert len(responses) > 0
-assert len(responses) == 4  # TurnStart, StepStart, StepComplete, TurnComplete
+assert (
+len(responses) == 7
+) # TurnStart, ShieldCallStart, ShieldCallComplete, StepStart, StepProgress, StepComplete, TurnComplete
 assert responses[0].event.payload.turn_id is not None


 @pytest.mark.asyncio
-async def test_run_multiple_shields_wrapper(chat_agent):
+async def test_run_multiple_shields_wrapper(get_chat_agent):
+chat_agent = await get_chat_agent
 messages = [UserMessage(content="Test message")]
 shields = ["test_shield"]

@ -254,69 +398,95 @@ async def test_run_multiple_shields_wrapper(chat_agent):
|
||||||
|
|
||||||
assert len(responses) == 2 # StepStart, StepComplete
|
assert len(responses) == 2 # StepStart, StepComplete
|
||||||
assert responses[0].event.payload.step_type.value == "shield_call"
|
assert responses[0].event.payload.step_type.value == "shield_call"
|
||||||
assert not responses[1].event.payload.step_details.response.is_violation
|
assert not responses[1].event.payload.step_details.violation
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@pytest.mark.skip(reason="Not yet implemented; need to mock out tool execution easily")
|
async def test_chat_agent_complex_turn(get_chat_agent):
|
||||||
async def test_chat_agent_complex_turn(chat_agent):
|
chat_agent = await get_chat_agent
|
||||||
# Setup
|
session_id = await chat_agent.create_session("Test Session")
|
||||||
session = chat_agent.create_session("Test Session")
|
|
||||||
request = AgentTurnCreateRequest(
|
request = AgentTurnCreateRequest(
|
||||||
agent_id="random",
|
agent_id=chat_agent.agent_id,
|
||||||
session_id=session.session_id,
|
session_id=session_id,
|
||||||
messages=[UserMessage(content="Tell me about AI and then use a tool.")],
|
messages=[UserMessage(content="Tell me about AI and then use a tool.")],
|
||||||
stream=True,
|
stream=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Execute the turn
|
|
||||||
responses = []
|
responses = []
|
||||||
async for response in chat_agent.create_and_execute_turn(request):
|
async for response in chat_agent.create_and_execute_turn(request):
|
||||||
responses.append(response)
|
responses.append(response)
|
||||||
|
|
||||||
# Assertions
|
|
||||||
assert len(responses) > 0
|
assert len(responses) > 0
|
||||||
|
|
||||||
# Check for the presence of different step types
|
|
||||||
step_types = [
|
step_types = [
|
||||||
response.event.payload.step_type
|
response.event.payload.step_type
|
||||||
for response in responses
|
for response in responses
|
||||||
if hasattr(response.event.payload, "step_type")
|
if hasattr(response.event.payload, "step_type")
|
||||||
]
|
]
|
||||||
|
|
||||||
assert "shield_call" in step_types, "Shield call step is missing"
|
assert StepType.shield_call in step_types, "Shield call step is missing"
|
||||||
assert "inference" in step_types, "Inference step is missing"
|
assert StepType.inference in step_types, "Inference step is missing"
|
||||||
assert "tool_execution" in step_types, "Tool execution step is missing"
|
|
||||||
|
|
||||||
# Check for the presence of start and complete events
|
|
||||||
event_types = [
|
event_types = [
|
||||||
response.event.payload.event_type
|
response.event.payload.event_type
|
||||||
for response in responses
|
for response in responses
|
||||||
if hasattr(response.event.payload, "event_type")
|
if hasattr(response.event.payload, "event_type")
|
||||||
]
|
]
|
||||||
assert "start" in event_types, "Start event is missing"
|
assert "turn_start" in event_types, "Start event is missing"
|
||||||
assert "complete" in event_types, "Complete event is missing"
|
assert "turn_complete" in event_types, "Complete event is missing"
|
||||||
|
|
||||||
# Check for the presence of tool call
|
|
||||||
tool_calls = [
|
|
||||||
response.event.payload.tool_call
|
|
||||||
for response in responses
|
|
||||||
if hasattr(response.event.payload, "tool_call")
|
|
||||||
]
|
|
||||||
assert any(
|
|
||||||
tool_call
|
|
||||||
for tool_call in tool_calls
|
|
||||||
if tool_call and tool_call.content.get("name") == "memory"
|
|
||||||
), "Memory tool call is missing"
|
|
||||||
|
|
||||||
# Check for the final turn complete event
|
|
||||||
assert any(
|
assert any(
|
||||||
isinstance(response.event.payload, AgentTurnResponseTurnCompletePayload)
|
isinstance(response.event.payload, AgentTurnResponseTurnCompletePayload)
|
||||||
for response in responses
|
for response in responses
|
||||||
), "Turn complete event is missing"
|
), "Turn complete event is missing"
|
||||||
|
turn_complete_payload = next(
|
||||||
|
response.event.payload
|
||||||
|
for response in responses
|
||||||
|
if isinstance(response.event.payload, AgentTurnResponseTurnCompletePayload)
|
||||||
|
)
|
||||||
|
turn = turn_complete_payload.turn
|
||||||
|
assert turn.input_messages == request.messages, "Input messages do not match"
|
||||||
|
|
||||||
# Verify the turn was added to the session
|
|
||||||
assert len(session.turns) == 1, "Turn was not added to the session"
|
@pytest.mark.asyncio
|
||||||
assert (
|
@pytest.mark.parametrize(
|
||||||
session.turns[0].input_messages == request.messages
|
"toolgroups, expected_memory, expected_code_interpreter",
|
||||||
), "Input messages do not match"
|
[
|
||||||
|
([], False, False), # no tools
|
||||||
|
([MEMORY_TOOLGROUP], True, False), # memory only
|
||||||
|
([CODE_INTERPRETER_TOOLGROUP], False, True), # code interpreter only
|
||||||
|
([MEMORY_TOOLGROUP, CODE_INTERPRETER_TOOLGROUP], True, True), # all tools
|
||||||
|
],
|
||||||
|
)
|
||||||
|
async def test_chat_agent_tools(
|
||||||
|
get_agents_impl, toolgroups, expected_memory, expected_code_interpreter
|
||||||
|
):
|
||||||
|
impl = await get_agents_impl
|
||||||
|
agent_config = AgentConfig(
|
||||||
|
model="test_model",
|
||||||
|
instructions="You are a helpful assistant.",
|
||||||
|
toolgroups=toolgroups,
|
||||||
|
tool_choice=ToolChoice.auto,
|
||||||
|
enable_session_persistence=False,
|
||||||
|
input_shields=["test_shield"],
|
||||||
|
)
|
||||||
|
response = await impl.create_agent(agent_config)
|
||||||
|
chat_agent = await impl.get_agent(response.agent_id)
|
||||||
|
|
||||||
|
tool_defs, _ = await chat_agent._get_tool_defs()
|
||||||
|
if expected_memory:
|
||||||
|
assert MEMORY_QUERY_TOOL in tool_defs
|
||||||
|
if expected_code_interpreter:
|
||||||
|
assert BuiltinTool.code_interpreter in tool_defs
|
||||||
|
if expected_memory and expected_code_interpreter:
|
||||||
|
# override the tools for turn
|
||||||
|
new_tool_defs, _ = await chat_agent._get_tool_defs(
|
||||||
|
toolgroups_for_turn=[
|
||||||
|
AgentToolGroupWithArgs(
|
||||||
|
name=MEMORY_TOOLGROUP,
|
||||||
|
args={"memory_banks": ["test_memory_bank"]},
|
||||||
|
)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
assert MEMORY_QUERY_TOOL in new_tool_defs
|
||||||
|
assert BuiltinTool.code_interpreter not in new_tool_defs
|
||||||
|
|
|
@@ -1,20 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the terms described in the LICENSE file in
-# the root directory of this source tree.
-
-from abc import ABC, abstractmethod
-from typing import List
-
-from llama_stack.apis.inference import Message
-
-
-class BaseTool(ABC):
-    @abstractmethod
-    def get_name(self) -> str:
-        raise NotImplementedError
-
-    @abstractmethod
-    async def run(self, messages: List[Message]) -> List[Message]:
-        raise NotImplementedError
@ -1,396 +0,0 @@
|
||||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
||||||
# All rights reserved.
|
|
||||||
#
|
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
|
||||||
# the root directory of this source tree.
|
|
||||||
|
|
||||||
import json
|
|
||||||
import logging
|
|
||||||
import re
|
|
||||||
import tempfile
|
|
||||||
|
|
||||||
from abc import abstractmethod
|
|
||||||
from typing import List, Optional
|
|
||||||
|
|
||||||
import requests
|
|
||||||
|
|
||||||
from .ipython_tool.code_execution import (
|
|
||||||
CodeExecutionContext,
|
|
||||||
CodeExecutionRequest,
|
|
||||||
CodeExecutor,
|
|
||||||
TOOLS_ATTACHMENT_KEY_REGEX,
|
|
||||||
)
|
|
||||||
|
|
||||||
from llama_stack.apis.inference import * # noqa: F403
|
|
||||||
from llama_stack.apis.agents import * # noqa: F403
|
|
||||||
|
|
||||||
from .base import BaseTool
|
|
||||||
|
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
def interpret_content_as_attachment(content: str) -> Optional[Attachment]:
|
|
||||||
match = re.search(TOOLS_ATTACHMENT_KEY_REGEX, content)
|
|
||||||
if match:
|
|
||||||
snippet = match.group(1)
|
|
||||||
data = json.loads(snippet)
|
|
||||||
return Attachment(
|
|
||||||
url=URL(uri="file://" + data["filepath"]), mime_type=data["mimetype"]
|
|
||||||
)
|
|
||||||
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
class SingleMessageBuiltinTool(BaseTool):
|
|
||||||
async def run(self, messages: List[CompletionMessage]) -> List[ToolResponseMessage]:
|
|
||||||
assert len(messages) == 1, f"Expected single message, got {len(messages)}"
|
|
||||||
|
|
||||||
message = messages[0]
|
|
||||||
assert len(message.tool_calls) == 1, "Expected a single tool call"
|
|
||||||
|
|
||||||
tool_call = messages[0].tool_calls[0]
|
|
||||||
|
|
||||||
query = tool_call.arguments["query"]
|
|
||||||
response: str = await self.run_impl(query)
|
|
||||||
|
|
||||||
message = ToolResponseMessage(
|
|
||||||
call_id=tool_call.call_id,
|
|
||||||
tool_name=tool_call.tool_name,
|
|
||||||
content=response,
|
|
||||||
)
|
|
||||||
return [message]
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
async def run_impl(self, query: str) -> str:
|
|
||||||
raise NotImplementedError()
|
|
||||||
|
|
||||||
|
|
||||||
class PhotogenTool(SingleMessageBuiltinTool):
|
|
||||||
def __init__(self, dump_dir: str) -> None:
|
|
||||||
self.dump_dir = dump_dir
|
|
||||||
|
|
||||||
def get_name(self) -> str:
|
|
||||||
return BuiltinTool.photogen.value
|
|
||||||
|
|
||||||
async def run_impl(self, query: str) -> str:
|
|
||||||
"""
|
|
||||||
Implement this to give the model an ability to generate images.
|
|
||||||
|
|
||||||
Return:
|
|
||||||
info = {
|
|
||||||
"filepath": str(image_filepath),
|
|
||||||
"mimetype": "image/png",
|
|
||||||
}
|
|
||||||
"""
|
|
||||||
raise NotImplementedError()
|
|
||||||
|
|
||||||
|
|
||||||
class SearchTool(SingleMessageBuiltinTool):
|
|
||||||
def __init__(self, engine: SearchEngineType, api_key: str, **kwargs) -> None:
|
|
||||||
self.api_key = api_key
|
|
||||||
self.engine_type = engine
|
|
||||||
if engine == SearchEngineType.bing:
|
|
||||||
self.engine = BingSearch(api_key, **kwargs)
|
|
||||||
elif engine == SearchEngineType.brave:
|
|
||||||
self.engine = BraveSearch(api_key, **kwargs)
|
|
||||||
elif engine == SearchEngineType.tavily:
|
|
||||||
self.engine = TavilySearch(api_key, **kwargs)
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Unknown search engine: {engine}")
|
|
||||||
|
|
||||||
def get_name(self) -> str:
|
|
||||||
return BuiltinTool.brave_search.value
|
|
||||||
|
|
||||||
async def run_impl(self, query: str) -> str:
|
|
||||||
return await self.engine.search(query)
|
|
||||||
|
|
||||||
|
|
||||||
class BingSearch:
|
|
||||||
def __init__(self, api_key: str, top_k: int = 3, **kwargs) -> None:
|
|
||||||
self.api_key = api_key
|
|
||||||
self.top_k = top_k
|
|
||||||
|
|
||||||
async def search(self, query: str) -> str:
|
|
||||||
url = "https://api.bing.microsoft.com/v7.0/search"
|
|
||||||
headers = {
|
|
||||||
"Ocp-Apim-Subscription-Key": self.api_key,
|
|
||||||
}
|
|
||||||
params = {
|
|
||||||
"count": self.top_k,
|
|
||||||
"textDecorations": True,
|
|
||||||
"textFormat": "HTML",
|
|
||||||
"q": query,
|
|
||||||
}
|
|
||||||
|
|
||||||
response = requests.get(url=url, params=params, headers=headers)
|
|
||||||
response.raise_for_status()
|
|
||||||
clean = self._clean_response(response.json())
|
|
||||||
return json.dumps(clean)
|
|
||||||
|
|
||||||
def _clean_response(self, search_response):
|
|
||||||
clean_response = []
|
|
||||||
query = search_response["queryContext"]["originalQuery"]
|
|
||||||
if "webPages" in search_response:
|
|
||||||
pages = search_response["webPages"]["value"]
|
|
||||||
for p in pages:
|
|
||||||
selected_keys = {"name", "url", "snippet"}
|
|
||||||
clean_response.append(
|
|
||||||
{k: v for k, v in p.items() if k in selected_keys}
|
|
||||||
)
|
|
||||||
if "news" in search_response:
|
|
||||||
clean_news = []
|
|
||||||
news = search_response["news"]["value"]
|
|
||||||
for n in news:
|
|
||||||
selected_keys = {"name", "url", "description"}
|
|
||||||
clean_news.append({k: v for k, v in n.items() if k in selected_keys})
|
|
||||||
|
|
||||||
clean_response.append(clean_news)
|
|
||||||
|
|
||||||
return {"query": query, "top_k": clean_response}
|
|
||||||
|
|
||||||
|
|
||||||
class BraveSearch:
|
|
||||||
def __init__(self, api_key: str) -> None:
|
|
||||||
self.api_key = api_key
|
|
||||||
|
|
||||||
async def search(self, query: str) -> str:
|
|
||||||
url = "https://api.search.brave.com/res/v1/web/search"
|
|
||||||
headers = {
|
|
||||||
"X-Subscription-Token": self.api_key,
|
|
||||||
"Accept-Encoding": "gzip",
|
|
||||||
"Accept": "application/json",
|
|
||||||
}
|
|
||||||
payload = {"q": query}
|
|
||||||
response = requests.get(url=url, params=payload, headers=headers)
|
|
||||||
return json.dumps(self._clean_brave_response(response.json()))
|
|
||||||
|
|
||||||
def _clean_brave_response(self, search_response, top_k=3):
|
|
||||||
query = None
|
|
||||||
clean_response = []
|
|
||||||
if "query" in search_response:
|
|
||||||
if "original" in search_response["query"]:
|
|
||||||
query = search_response["query"]["original"]
|
|
||||||
if "mixed" in search_response:
|
|
||||||
mixed_results = search_response["mixed"]
|
|
||||||
for m in mixed_results["main"][:top_k]:
|
|
||||||
r_type = m["type"]
|
|
||||||
results = search_response[r_type]["results"]
|
|
||||||
if r_type == "web":
|
|
||||||
# For web data - add a single output from the search
|
|
||||||
idx = m["index"]
|
|
||||||
selected_keys = [
|
|
||||||
"type",
|
|
||||||
"title",
|
|
||||||
"url",
|
|
||||||
"description",
|
|
||||||
"date",
|
|
||||||
"extra_snippets",
|
|
||||||
]
|
|
||||||
cleaned = {
|
|
||||||
k: v for k, v in results[idx].items() if k in selected_keys
|
|
||||||
}
|
|
||||||
elif r_type == "faq":
|
|
||||||
# For faw data - take a list of all the questions & answers
|
|
||||||
selected_keys = ["type", "question", "answer", "title", "url"]
|
|
||||||
cleaned = []
|
|
||||||
for q in results:
|
|
||||||
cleaned.append(
|
|
||||||
{k: v for k, v in q.items() if k in selected_keys}
|
|
||||||
)
|
|
||||||
elif r_type == "infobox":
|
|
||||||
idx = m["index"]
|
|
||||||
selected_keys = [
|
|
||||||
"type",
|
|
||||||
"title",
|
|
||||||
"url",
|
|
||||||
"description",
|
|
||||||
"long_desc",
|
|
||||||
]
|
|
||||||
cleaned = {
|
|
||||||
k: v for k, v in results[idx].items() if k in selected_keys
|
|
||||||
}
|
|
||||||
elif r_type == "videos":
|
|
||||||
selected_keys = [
|
|
||||||
"type",
|
|
||||||
"url",
|
|
||||||
"title",
|
|
||||||
"description",
|
|
||||||
"date",
|
|
||||||
]
|
|
||||||
cleaned = []
|
|
||||||
for q in results:
|
|
||||||
cleaned.append(
|
|
||||||
{k: v for k, v in q.items() if k in selected_keys}
|
|
||||||
)
|
|
||||||
elif r_type == "locations":
|
|
||||||
# For faw data - take a list of all the questions & answers
|
|
||||||
selected_keys = [
|
|
||||||
"type",
|
|
||||||
"title",
|
|
||||||
"url",
|
|
||||||
"description",
|
|
||||||
"coordinates",
|
|
||||||
"postal_address",
|
|
||||||
"contact",
|
|
||||||
"rating",
|
|
||||||
"distance",
|
|
||||||
"zoom_level",
|
|
||||||
]
|
|
||||||
cleaned = []
|
|
||||||
for q in results:
|
|
||||||
cleaned.append(
|
|
||||||
{k: v for k, v in q.items() if k in selected_keys}
|
|
||||||
)
|
|
||||||
elif r_type == "news":
|
|
||||||
# For faw data - take a list of all the questions & answers
|
|
||||||
selected_keys = [
|
|
||||||
"type",
|
|
||||||
"title",
|
|
||||||
"url",
|
|
||||||
"description",
|
|
||||||
]
|
|
||||||
cleaned = []
|
|
||||||
for q in results:
|
|
||||||
cleaned.append(
|
|
||||||
{k: v for k, v in q.items() if k in selected_keys}
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
cleaned = []
|
|
||||||
|
|
||||||
clean_response.append(cleaned)
|
|
||||||
|
|
||||||
return {"query": query, "top_k": clean_response}
|
|
||||||
|
|
||||||
|
|
||||||
class TavilySearch:
|
|
||||||
def __init__(self, api_key: str) -> None:
|
|
||||||
self.api_key = api_key
|
|
||||||
|
|
||||||
async def search(self, query: str) -> str:
|
|
||||||
response = requests.post(
|
|
||||||
"https://api.tavily.com/search",
|
|
||||||
json={"api_key": self.api_key, "query": query},
|
|
||||||
)
|
|
||||||
return json.dumps(self._clean_tavily_response(response.json()))
|
|
||||||
|
|
||||||
def _clean_tavily_response(self, search_response, top_k=3):
|
|
||||||
return {"query": search_response["query"], "top_k": search_response["results"]}
|
|
||||||
|
|
||||||
|
|
||||||
class WolframAlphaTool(SingleMessageBuiltinTool):
|
|
||||||
def __init__(self, api_key: str) -> None:
|
|
||||||
self.api_key = api_key
|
|
||||||
self.url = "https://api.wolframalpha.com/v2/query"
|
|
||||||
|
|
||||||
def get_name(self) -> str:
|
|
||||||
return BuiltinTool.wolfram_alpha.value
|
|
||||||
|
|
||||||
async def run_impl(self, query: str) -> str:
|
|
||||||
params = {
|
|
||||||
"input": query,
|
|
||||||
"appid": self.api_key,
|
|
||||||
"format": "plaintext",
|
|
||||||
"output": "json",
|
|
||||||
}
|
|
||||||
response = requests.get(
|
|
||||||
self.url,
|
|
||||||
params=params,
|
|
||||||
)
|
|
||||||
|
|
||||||
return json.dumps(self._clean_wolfram_alpha_response(response.json()))
|
|
||||||
|
|
||||||
def _clean_wolfram_alpha_response(self, wa_response):
|
|
||||||
remove = {
|
|
||||||
"queryresult": [
|
|
||||||
"datatypes",
|
|
||||||
"error",
|
|
||||||
"timedout",
|
|
||||||
"timedoutpods",
|
|
||||||
"numpods",
|
|
||||||
"timing",
|
|
||||||
"parsetiming",
|
|
||||||
"parsetimedout",
|
|
||||||
"recalculate",
|
|
||||||
"id",
|
|
||||||
"host",
|
|
||||||
"server",
|
|
||||||
"related",
|
|
||||||
"version",
|
|
||||||
{
|
|
||||||
"pods": [
|
|
||||||
"scanner",
|
|
||||||
"id",
|
|
||||||
"error",
|
|
||||||
"expressiontypes",
|
|
||||||
"states",
|
|
||||||
"infos",
|
|
||||||
"position",
|
|
||||||
"numsubpods",
|
|
||||||
]
|
|
||||||
},
|
|
||||||
"assumptions",
|
|
||||||
],
|
|
||||||
}
|
|
||||||
for main_key in remove:
|
|
||||||
for key_to_remove in remove[main_key]:
|
|
||||||
try:
|
|
||||||
if key_to_remove == "assumptions":
|
|
||||||
if "assumptions" in wa_response[main_key]:
|
|
||||||
del wa_response[main_key][key_to_remove]
|
|
||||||
if isinstance(key_to_remove, dict):
|
|
||||||
for sub_key in key_to_remove:
|
|
||||||
if sub_key == "pods":
|
|
||||||
for i in range(len(wa_response[main_key][sub_key])):
|
|
||||||
if (
|
|
||||||
wa_response[main_key][sub_key][i]["title"]
|
|
||||||
== "Result"
|
|
||||||
):
|
|
||||||
del wa_response[main_key][sub_key][i + 1 :]
|
|
||||||
break
|
|
||||||
sub_items = wa_response[main_key][sub_key]
|
|
||||||
for i in range(len(sub_items)):
|
|
||||||
for sub_key_to_remove in key_to_remove[sub_key]:
|
|
||||||
if sub_key_to_remove in sub_items[i]:
|
|
||||||
del sub_items[i][sub_key_to_remove]
|
|
||||||
elif key_to_remove in wa_response[main_key]:
|
|
||||||
del wa_response[main_key][key_to_remove]
|
|
||||||
except KeyError:
|
|
||||||
pass
|
|
||||||
return wa_response
|
|
||||||
|
|
||||||
|
|
||||||
class CodeInterpreterTool(BaseTool):
|
|
||||||
def __init__(self) -> None:
|
|
||||||
ctx = CodeExecutionContext(
|
|
||||||
matplotlib_dump_dir=tempfile.mkdtemp(),
|
|
||||||
)
|
|
||||||
self.code_executor = CodeExecutor(ctx)
|
|
||||||
|
|
||||||
def get_name(self) -> str:
|
|
||||||
return BuiltinTool.code_interpreter.value
|
|
||||||
|
|
||||||
async def run(self, messages: List[CompletionMessage]) -> List[ToolResponseMessage]:
|
|
||||||
message = messages[0]
|
|
||||||
assert len(message.tool_calls) == 1, "Expected a single tool call"
|
|
||||||
|
|
||||||
tool_call = messages[0].tool_calls[0]
|
|
||||||
script = tool_call.arguments["code"]
|
|
||||||
|
|
||||||
req = CodeExecutionRequest(scripts=[script])
|
|
||||||
res = self.code_executor.execute(req)
|
|
||||||
|
|
||||||
pieces = [res["process_status"]]
|
|
||||||
for out_type in ["stdout", "stderr"]:
|
|
||||||
res_out = res[out_type]
|
|
||||||
if res_out != "":
|
|
||||||
pieces.extend([f"[{out_type}]", res_out, f"[/{out_type}]"])
|
|
||||||
if out_type == "stderr":
|
|
||||||
log.error(f"ipython tool error: ↓\n{res_out}")
|
|
||||||
|
|
||||||
message = ToolResponseMessage(
|
|
||||||
call_id=tool_call.call_id,
|
|
||||||
tool_name=tool_call.tool_name,
|
|
||||||
content="\n".join(pieces),
|
|
||||||
)
|
|
||||||
return [message]
|
|
|
@ -1,42 +0,0 @@
|
||||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
||||||
# All rights reserved.
|
|
||||||
#
|
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
|
||||||
# the root directory of this source tree.
|
|
||||||
|
|
||||||
from typing import List
|
|
||||||
|
|
||||||
from llama_stack.apis.inference import Message
|
|
||||||
from llama_stack.apis.safety import Safety
|
|
||||||
|
|
||||||
from ..safety import ShieldRunnerMixin
|
|
||||||
from .builtin import BaseTool
|
|
||||||
|
|
||||||
|
|
||||||
class SafeTool(BaseTool, ShieldRunnerMixin):
|
|
||||||
"""A tool that makes other tools safety enabled"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
tool: BaseTool,
|
|
||||||
safety_api: Safety,
|
|
||||||
input_shields: List[str] = None,
|
|
||||||
output_shields: List[str] = None,
|
|
||||||
):
|
|
||||||
self._tool = tool
|
|
||||||
ShieldRunnerMixin.__init__(
|
|
||||||
self, safety_api, input_shields=input_shields, output_shields=output_shields
|
|
||||||
)
|
|
||||||
|
|
||||||
def get_name(self) -> str:
|
|
||||||
return self._tool.get_name()
|
|
||||||
|
|
||||||
async def run(self, messages: List[Message]) -> List[Message]:
|
|
||||||
if self.input_shields:
|
|
||||||
await self.run_multiple_shields(messages, self.input_shields)
|
|
||||||
# run the underlying tool
|
|
||||||
res = await self._tool.run(messages)
|
|
||||||
if self.output_shields:
|
|
||||||
await self.run_multiple_shields(messages, self.output_shields)
|
|
||||||
|
|
||||||
return res
|
|
|
@@ -5,5 +5,14 @@
 # the root directory of this source tree.
 from pydantic import BaseModel

+from llama_stack.distribution.utils.config_dirs import RUNTIME_BASE_DIR
+from llama_stack.providers.utils.kvstore.config import (
+    KVStoreConfig,
+    SqliteKVStoreConfig,
+)

-class LocalFSDatasetIOConfig(BaseModel): ...
+
+class LocalFSDatasetIOConfig(BaseModel):
+    kvstore: KVStoreConfig = SqliteKVStoreConfig(
+        db_path=(RUNTIME_BASE_DIR / "localfs_datasetio.db").as_posix()
+    )  # Uses SQLite config specific to localfs storage
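For context, a minimal sketch (not part of this commit) of overriding the new kvstore field, for example to point the dataset registry at a scratch location. The module path and the db_path field name are assumed from the snippet above.

# Hypothetical override of the default registry location; module path assumed.
from llama_stack.providers.inline.datasetio.localfs.config import LocalFSDatasetIOConfig
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig

config = LocalFSDatasetIOConfig(
    kvstore=SqliteKVStoreConfig(db_path="/tmp/localfs_datasetio_test.db"),
)
print(config.kvstore.db_path)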
@@ -18,10 +18,14 @@ from llama_stack.apis.datasets import Dataset

 from llama_stack.providers.datatypes import DatasetsProtocolPrivate
 from llama_stack.providers.utils.datasetio.url_utils import get_dataframe_from_url
+from llama_stack.providers.utils.kvstore import kvstore_impl

 from .config import LocalFSDatasetIOConfig


+DATASETS_PREFIX = "localfs_datasets:"


 class BaseDataset(ABC):
     def __init__(self, *args, **kwargs) -> None:
         super().__init__(*args, **kwargs)

@@ -86,8 +90,22 @@ class LocalFSDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
         self.config = config
         # local registry for keeping track of datasets within the provider
         self.dataset_infos = {}
+        self.kvstore = None

-    async def initialize(self) -> None: ...
+    async def initialize(self) -> None:
+        self.kvstore = await kvstore_impl(self.config.kvstore)
+        # Load existing datasets from kvstore
+        start_key = DATASETS_PREFIX
+        end_key = f"{DATASETS_PREFIX}\xff"
+        stored_datasets = await self.kvstore.range(start_key, end_key)
+
+        for dataset in stored_datasets:
+            dataset = Dataset.model_validate_json(dataset)
+            dataset_impl = PandasDataframeDataset(dataset)
+            self.dataset_infos[dataset.identifier] = DatasetInfo(
+                dataset_def=dataset,
+                dataset_impl=dataset_impl,
+            )

     async def shutdown(self) -> None: ...

@@ -95,6 +113,12 @@ class LocalFSDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
         self,
         dataset: Dataset,
     ) -> None:
+        # Store in kvstore
+        key = f"{DATASETS_PREFIX}{dataset.identifier}"
+        await self.kvstore.set(
+            key=key,
+            value=dataset.json(),
+        )
         dataset_impl = PandasDataframeDataset(dataset)
         self.dataset_infos[dataset.identifier] = DatasetInfo(
             dataset_def=dataset,
@@ -102,6 +126,8 @@ class LocalFSDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
         )

     async def unregister_dataset(self, dataset_id: str) -> None:
+        key = f"{DATASETS_PREFIX}{dataset_id}"
+        await self.kvstore.delete(key=key)
         del self.dataset_infos[dataset_id]

     async def get_rows_paginated(
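For context, a self-contained sketch of the registry pattern used above: datasets are written under a key prefix and reloaded with a prefix-range scan on startup. The dict stands in for the real kvstore; only the key scheme mirrors the provider code.

# Illustration of prefix-keyed persistence; the dict is a stand-in for the kvstore.
DATASETS_PREFIX = "localfs_datasets:"
store: dict = {}

def register(identifier: str, serialized: str) -> None:
    store[f"{DATASETS_PREFIX}{identifier}"] = serialized

def load_all() -> list:
    start, end = DATASETS_PREFIX, f"{DATASETS_PREFIX}\xff"
    return [v for k, v in sorted(store.items()) if start <= k < end]

register("eval-set", '{"identifier": "eval-set"}')
assert load_all() == ['{"identifier": "eval-set"}']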
@ -6,7 +6,6 @@
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
from typing import AsyncGenerator, List, Optional, Union
|
from typing import AsyncGenerator, List, Optional, Union
|
||||||
|
|
||||||
from llama_models.llama3.api.datatypes import (
|
from llama_models.llama3.api.datatypes import (
|
||||||
|
@ -37,7 +36,6 @@ from llama_stack.apis.inference import (
|
||||||
ToolCallParseStatus,
|
ToolCallParseStatus,
|
||||||
ToolChoice,
|
ToolChoice,
|
||||||
)
|
)
|
||||||
|
|
||||||
from llama_stack.apis.models import Model, ModelType
|
from llama_stack.apis.models import Model, ModelType
|
||||||
from llama_stack.providers.datatypes import ModelsProtocolPrivate
|
from llama_stack.providers.datatypes import ModelsProtocolPrivate
|
||||||
from llama_stack.providers.utils.inference.embedding_mixin import (
|
from llama_stack.providers.utils.inference.embedding_mixin import (
|
||||||
|
@ -262,7 +260,7 @@ class MetaReferenceInferenceImpl(
|
||||||
response_format: Optional[ResponseFormat] = None,
|
response_format: Optional[ResponseFormat] = None,
|
||||||
tools: Optional[List[ToolDefinition]] = None,
|
tools: Optional[List[ToolDefinition]] = None,
|
||||||
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
|
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
|
||||||
tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
|
tool_prompt_format: Optional[ToolPromptFormat] = None,
|
||||||
stream: Optional[bool] = False,
|
stream: Optional[bool] = False,
|
||||||
logprobs: Optional[LogProbConfig] = None,
|
logprobs: Optional[LogProbConfig] = None,
|
||||||
) -> AsyncGenerator:
|
) -> AsyncGenerator:
|
||||||
|
|
|
@ -22,6 +22,7 @@ from llama_stack.providers.datatypes import Model, ModelsProtocolPrivate
|
||||||
from llama_stack.providers.utils.inference.embedding_mixin import (
|
from llama_stack.providers.utils.inference.embedding_mixin import (
|
||||||
SentenceTransformerEmbeddingMixin,
|
SentenceTransformerEmbeddingMixin,
|
||||||
)
|
)
|
||||||
|
|
||||||
from .config import SentenceTransformersInferenceConfig
|
from .config import SentenceTransformersInferenceConfig
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
@ -67,7 +68,7 @@ class SentenceTransformersInferenceImpl(
|
||||||
response_format: Optional[ResponseFormat] = None,
|
response_format: Optional[ResponseFormat] = None,
|
||||||
tools: Optional[List[ToolDefinition]] = None,
|
tools: Optional[List[ToolDefinition]] = None,
|
||||||
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
|
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
|
||||||
tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
|
tool_prompt_format: Optional[ToolPromptFormat] = None,
|
||||||
stream: Optional[bool] = False,
|
stream: Optional[bool] = False,
|
||||||
logprobs: Optional[LogProbConfig] = None,
|
logprobs: Optional[LogProbConfig] = None,
|
||||||
) -> AsyncGenerator:
|
) -> AsyncGenerator:
|
||||||
|
|
|
@ -10,10 +10,8 @@ import uuid
|
||||||
from typing import AsyncGenerator, List, Optional
|
from typing import AsyncGenerator, List, Optional
|
||||||
|
|
||||||
from llama_models.llama3.api.chat_format import ChatFormat
|
from llama_models.llama3.api.chat_format import ChatFormat
|
||||||
|
|
||||||
from llama_models.llama3.api.tokenizer import Tokenizer
|
from llama_models.llama3.api.tokenizer import Tokenizer
|
||||||
from llama_models.sku_list import resolve_model
|
from llama_models.sku_list import resolve_model
|
||||||
|
|
||||||
from vllm.engine.arg_utils import AsyncEngineArgs
|
from vllm.engine.arg_utils import AsyncEngineArgs
|
||||||
from vllm.engine.async_llm_engine import AsyncLLMEngine
|
from vllm.engine.async_llm_engine import AsyncLLMEngine
|
||||||
from vllm.sampling_params import SamplingParams as VLLMSamplingParams
|
from vllm.sampling_params import SamplingParams as VLLMSamplingParams
|
||||||
|
@ -36,7 +34,6 @@ from llama_stack.apis.inference import (
|
||||||
ToolPromptFormat,
|
ToolPromptFormat,
|
||||||
)
|
)
|
||||||
from llama_stack.apis.models import Model
|
from llama_stack.apis.models import Model
|
||||||
|
|
||||||
from llama_stack.providers.datatypes import ModelsProtocolPrivate
|
from llama_stack.providers.datatypes import ModelsProtocolPrivate
|
||||||
from llama_stack.providers.utils.inference.openai_compat import (
|
from llama_stack.providers.utils.inference.openai_compat import (
|
||||||
OpenAICompatCompletionChoice,
|
OpenAICompatCompletionChoice,
|
||||||
|
@ -50,7 +47,6 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||||
|
|
||||||
from .config import VLLMConfig
|
from .config import VLLMConfig
|
||||||
|
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@ -67,7 +63,7 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
|
||||||
self.formatter = ChatFormat(Tokenizer.get_instance())
|
self.formatter = ChatFormat(Tokenizer.get_instance())
|
||||||
|
|
||||||
async def initialize(self):
|
async def initialize(self):
|
||||||
log.info("Initializing vLLM inference adapter")
|
log.info("Initializing vLLM inference provider.")
|
||||||
|
|
||||||
# Disable usage stats reporting. This would be a surprising thing for most
|
# Disable usage stats reporting. This would be a surprising thing for most
|
||||||
# people to find out was on by default.
|
# people to find out was on by default.
|
||||||
|
@@ -95,15 +91,36 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
         self.engine = AsyncLLMEngine.from_engine_args(engine_args)

     async def shutdown(self):
-        """Shutdown the vLLM inference adapter."""
-        log.info("Shutting down vLLM inference adapter")
+        """Shut down the vLLM inference adapter."""
+        log.info("Shutting down vLLM inference provider.")
         if self.engine:
             self.engine.shutdown_background_loop()

-    async def register_model(self, model: Model) -> None:
-        raise ValueError(
-            "You cannot dynamically add a model to a running vllm instance"
-        )
+    # Note that the return type of the superclass method is WRONG
+    async def register_model(self, model: Model) -> Model:
+        """
+        Callback that is called when the server associates an inference endpoint
+        with an inference provider.
+
+        :param model: Object that encapsulates parameters necessary for identifying
+         a specific LLM.
+
+        :returns: The input ``Model`` object. It may or may not be permissible
+         to change fields before returning this object.
+        """
+        log.info(f"Registering model {model.identifier} with vLLM inference provider.")
+        # The current version of this provider is hard-coded to serve only
+        # the model specified in the YAML config file.
+        configured_model = resolve_model(self.config.model)
+        registered_model = resolve_model(model.model_id)
+
+        if configured_model.core_model_id != registered_model.core_model_id:
+            raise ValueError(
+                f"Requested model '{model.identifier}' is different from "
+                f"model '{self.config.model}' that this provider "
+                f"is configured to serve"
+            )
+        return model
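For context, a standalone sketch of the check register_model performs above: the provider only accepts registrations that resolve to the model it was configured to serve. resolve_model is replaced by a plain string comparison here purely for illustration.

# Illustrative version of the identity check; not the provider's actual code path.
def check_registration(configured_model_id: str, requested_model_id: str) -> None:
    if configured_model_id != requested_model_id:
        raise ValueError(
            f"Requested model '{requested_model_id}' is different from "
            f"model '{configured_model_id}' that this provider is configured to serve"
        )

check_registration("Llama3.1-8B-Instruct", "Llama3.1-8B-Instruct")  # matches, no error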
def _sampling_params(self, sampling_params: SamplingParams) -> VLLMSamplingParams:
|
def _sampling_params(self, sampling_params: SamplingParams) -> VLLMSamplingParams:
|
||||||
if sampling_params is None:
|
if sampling_params is None:
|
||||||
|
@ -146,7 +163,7 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
|
||||||
sampling_params: Optional[SamplingParams] = SamplingParams(),
|
sampling_params: Optional[SamplingParams] = SamplingParams(),
|
||||||
tools: Optional[List[ToolDefinition]] = None,
|
tools: Optional[List[ToolDefinition]] = None,
|
||||||
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
|
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
|
||||||
tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
|
tool_prompt_format: Optional[ToolPromptFormat] = None,
|
||||||
response_format: Optional[ResponseFormat] = None,
|
response_format: Optional[ResponseFormat] = None,
|
||||||
stream: Optional[bool] = False,
|
stream: Optional[bool] = False,
|
||||||
logprobs: Optional[LogProbConfig] = None,
|
logprobs: Optional[LogProbConfig] = None,
|
||||||
|
@ -167,7 +184,9 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
|
||||||
log.info("Sampling params: %s", sampling_params)
|
log.info("Sampling params: %s", sampling_params)
|
||||||
request_id = _random_uuid()
|
request_id = _random_uuid()
|
||||||
|
|
||||||
prompt = await chat_completion_request_to_prompt(request, self.formatter)
|
prompt = await chat_completion_request_to_prompt(
|
||||||
|
request, self.config.model, self.formatter
|
||||||
|
)
|
||||||
vllm_sampling_params = self._sampling_params(request.sampling_params)
|
vllm_sampling_params = self._sampling_params(request.sampling_params)
|
||||||
results_generator = self.engine.generate(
|
results_generator = self.engine.generate(
|
||||||
prompt, vllm_sampling_params, request_id
|
prompt, vllm_sampling_params, request_id
|
||||||
|
|
|
@ -16,8 +16,6 @@ import torch
|
||||||
from llama_models.datatypes import Model
|
from llama_models.datatypes import Model
|
||||||
from llama_models.sku_list import resolve_model
|
from llama_models.sku_list import resolve_model
|
||||||
|
|
||||||
from llama_stack.apis.post_training import DatasetFormat
|
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from torchtune.data._messages import InputOutputToMessages, ShareGPTToMessages
|
from torchtune.data._messages import InputOutputToMessages, ShareGPTToMessages
|
||||||
|
|
||||||
|
@ -27,6 +25,8 @@ from torchtune.models.llama3_1 import lora_llama3_1_8b
|
||||||
from torchtune.models.llama3_2 import lora_llama3_2_3b
|
from torchtune.models.llama3_2 import lora_llama3_2_3b
|
||||||
from torchtune.modules.transforms import Transform
|
from torchtune.modules.transforms import Transform
|
||||||
|
|
||||||
|
from llama_stack.apis.post_training import DatasetFormat
|
||||||
|
|
||||||
|
|
||||||
class ModelConfig(BaseModel):
|
class ModelConfig(BaseModel):
|
||||||
model_definition: Any
|
model_definition: Any
|
||||||
|
|
|
@ -14,6 +14,24 @@ from typing import Any, Dict, List, Optional, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from llama_models.sku_list import resolve_model
|
from llama_models.sku_list import resolve_model
|
||||||
|
from torch import nn
|
||||||
|
from torch.optim import Optimizer
|
||||||
|
from torch.utils.data import DataLoader, DistributedSampler
|
||||||
|
from torchtune import modules, training, utils as torchtune_utils
|
||||||
|
from torchtune.data import padded_collate_sft
|
||||||
|
|
||||||
|
from torchtune.modules.loss import CEWithChunkedOutputLoss
|
||||||
|
from torchtune.modules.peft import (
|
||||||
|
get_adapter_params,
|
||||||
|
get_adapter_state_dict,
|
||||||
|
get_lora_module_names,
|
||||||
|
get_merged_lora_ckpt,
|
||||||
|
set_trainable_params,
|
||||||
|
validate_missing_and_unexpected_for_lora,
|
||||||
|
)
|
||||||
|
from torchtune.training.lr_schedulers import get_cosine_schedule_with_warmup
|
||||||
|
from torchtune.training.metric_logging import DiskLogger
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
from llama_stack.apis.common.training_types import PostTrainingMetric
|
from llama_stack.apis.common.training_types import PostTrainingMetric
|
||||||
from llama_stack.apis.datasetio import DatasetIO
|
from llama_stack.apis.datasetio import DatasetIO
|
||||||
|
@ -41,24 +59,6 @@ from llama_stack.providers.inline.post_training.torchtune.config import (
|
||||||
TorchtunePostTrainingConfig,
|
TorchtunePostTrainingConfig,
|
||||||
)
|
)
|
||||||
from llama_stack.providers.inline.post_training.torchtune.datasets.sft import SFTDataset
|
from llama_stack.providers.inline.post_training.torchtune.datasets.sft import SFTDataset
|
||||||
from torch import nn
|
|
||||||
from torch.optim import Optimizer
|
|
||||||
from torch.utils.data import DataLoader, DistributedSampler
|
|
||||||
from torchtune import modules, training, utils as torchtune_utils
|
|
||||||
from torchtune.data import padded_collate_sft
|
|
||||||
|
|
||||||
from torchtune.modules.loss import CEWithChunkedOutputLoss
|
|
||||||
from torchtune.modules.peft import (
|
|
||||||
get_adapter_params,
|
|
||||||
get_adapter_state_dict,
|
|
||||||
get_lora_module_names,
|
|
||||||
get_merged_lora_ckpt,
|
|
||||||
set_trainable_params,
|
|
||||||
validate_missing_and_unexpected_for_lora,
|
|
||||||
)
|
|
||||||
from torchtune.training.lr_schedulers import get_cosine_schedule_with_warmup
|
|
||||||
from torchtune.training.metric_logging import DiskLogger
|
|
||||||
from tqdm import tqdm
|
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
|
@@ -4,10 +4,10 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from .config import CodeShieldConfig
+from .config import CodeScannerConfig


-async def get_provider_impl(config: CodeShieldConfig, deps):
+async def get_provider_impl(config: CodeScannerConfig, deps):
     from .code_scanner import MetaReferenceCodeScannerSafetyImpl

     impl = MetaReferenceCodeScannerSafetyImpl(config, deps)
@@ -156,7 +156,7 @@ class BraintrustScoringImpl(
         provider_data = self.get_request_provider_data()
         if provider_data is None or not provider_data.openai_api_key:
             raise ValueError(
-                'Pass OpenAI API Key in the header X-LlamaStack-ProviderData as { "openai_api_key": <your api key>}'
+                'Pass OpenAI API Key in the header X-LlamaStack-Provider-Data as { "openai_api_key": <your api key>}'
             )
         self.config.openai_api_key = provider_data.openai_api_key
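For context, a sketch of a client request carrying the corrected header name. The endpoint URL and payload are placeholders; only the header format comes from the error message above.

# Hypothetical client-side usage of the provider-data header.
import json
import requests

headers = {
    "X-LlamaStack-Provider-Data": json.dumps({"openai_api_key": "<your api key>"}),
}
# requests.post("http://localhost:5000/<scoring endpoint>", headers=headers, json={...})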
llama_stack/providers/inline/tool_runtime/__init__.py (new file, 5 lines)
|
@ -0,0 +1,5 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
|
@ -0,0 +1,16 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
from .code_interpreter import CodeInterpreterToolRuntimeImpl
|
||||||
|
from .config import CodeInterpreterToolConfig
|
||||||
|
|
||||||
|
__all__ = ["CodeInterpreterToolConfig", "CodeInterpreterToolRuntimeImpl"]
|
||||||
|
|
||||||
|
|
||||||
|
async def get_provider_impl(config: CodeInterpreterToolConfig, _deps):
|
||||||
|
impl = CodeInterpreterToolRuntimeImpl(config)
|
||||||
|
await impl.initialize()
|
||||||
|
return impl
|
|
@ -0,0 +1,75 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import tempfile
|
||||||
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
|
from llama_stack.apis.common.content_types import URL
|
||||||
|
from llama_stack.apis.tools import (
|
||||||
|
Tool,
|
||||||
|
ToolDef,
|
||||||
|
ToolInvocationResult,
|
||||||
|
ToolParameter,
|
||||||
|
ToolRuntime,
|
||||||
|
)
|
||||||
|
from llama_stack.providers.datatypes import ToolsProtocolPrivate
|
||||||
|
|
||||||
|
from .code_execution import CodeExecutionContext, CodeExecutionRequest, CodeExecutor
|
||||||
|
from .config import CodeInterpreterToolConfig
|
||||||
|
|
||||||
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class CodeInterpreterToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
|
||||||
|
def __init__(self, config: CodeInterpreterToolConfig):
|
||||||
|
self.config = config
|
||||||
|
ctx = CodeExecutionContext(
|
||||||
|
matplotlib_dump_dir=tempfile.mkdtemp(),
|
||||||
|
)
|
||||||
|
self.code_executor = CodeExecutor(ctx)
|
||||||
|
|
||||||
|
async def initialize(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def register_tool(self, tool: Tool):
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def unregister_tool(self, tool_id: str) -> None:
|
||||||
|
return
|
||||||
|
|
||||||
|
async def list_runtime_tools(
|
||||||
|
self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
|
||||||
|
) -> List[ToolDef]:
|
||||||
|
return [
|
||||||
|
ToolDef(
|
||||||
|
name="code_interpreter",
|
||||||
|
description="Execute code",
|
||||||
|
parameters=[
|
||||||
|
ToolParameter(
|
||||||
|
name="code",
|
||||||
|
description="The code to execute",
|
||||||
|
parameter_type="string",
|
||||||
|
),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
|
async def invoke_tool(
|
||||||
|
self, tool_name: str, args: Dict[str, Any]
|
||||||
|
) -> ToolInvocationResult:
|
||||||
|
script = args["code"]
|
||||||
|
req = CodeExecutionRequest(scripts=[script])
|
||||||
|
res = self.code_executor.execute(req)
|
||||||
|
pieces = [res["process_status"]]
|
||||||
|
for out_type in ["stdout", "stderr"]:
|
||||||
|
res_out = res[out_type]
|
||||||
|
if res_out != "":
|
||||||
|
pieces.extend([f"[{out_type}]", res_out, f"[/{out_type}]"])
|
||||||
|
if out_type == "stderr":
|
||||||
|
log.error(f"ipython tool error: ↓\n{res_out}")
|
||||||
|
return ToolInvocationResult(content="\n".join(pieces))
|
|
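For context, a usage sketch for the runtime defined above. The package path is assumed from the new files in this commit, and running it requires the local code-execution sandbox that backs CodeExecutor.

# Hedged usage sketch; module path assumed, not part of this commit.
import asyncio

from llama_stack.providers.inline.tool_runtime.code_interpreter import get_provider_impl
from llama_stack.providers.inline.tool_runtime.code_interpreter.config import (
    CodeInterpreterToolConfig,
)

async def main() -> None:
    runtime = await get_provider_impl(CodeInterpreterToolConfig(), {})
    result = await runtime.invoke_tool("code_interpreter", {"code": "print(1 + 1)"})
    print(result.content)

asyncio.run(main())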
@ -0,0 +1,11 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
|
||||||
|
class CodeInterpreterToolConfig(BaseModel):
|
||||||
|
pass
|
llama_stack/providers/inline/tool_runtime/memory/__init__.py (new file, 20 lines)
|
@ -0,0 +1,20 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
from typing import Any, Dict
|
||||||
|
|
||||||
|
from llama_stack.providers.datatypes import Api
|
||||||
|
|
||||||
|
from .config import MemoryToolRuntimeConfig
|
||||||
|
from .memory import MemoryToolRuntimeImpl
|
||||||
|
|
||||||
|
|
||||||
|
async def get_provider_impl(config: MemoryToolRuntimeConfig, deps: Dict[str, Any]):
|
||||||
|
impl = MemoryToolRuntimeImpl(
|
||||||
|
config, deps[Api.memory], deps[Api.memory_banks], deps[Api.inference]
|
||||||
|
)
|
||||||
|
await impl.initialize()
|
||||||
|
return impl
|
llama_stack/providers/inline/tool_runtime/memory/config.py (new file, 90 lines)
|
@ -0,0 +1,90 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
from enum import Enum
|
||||||
|
from typing import Annotated, List, Literal, Union
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
|
||||||
|
class _MemoryBankConfigCommon(BaseModel):
|
||||||
|
bank_id: str
|
||||||
|
|
||||||
|
|
||||||
|
class VectorMemoryBankConfig(_MemoryBankConfigCommon):
|
||||||
|
type: Literal["vector"] = "vector"
|
||||||
|
|
||||||
|
|
||||||
|
class KeyValueMemoryBankConfig(_MemoryBankConfigCommon):
|
||||||
|
type: Literal["keyvalue"] = "keyvalue"
|
||||||
|
keys: List[str] # what keys to focus on
|
||||||
|
|
||||||
|
|
||||||
|
class KeywordMemoryBankConfig(_MemoryBankConfigCommon):
|
||||||
|
type: Literal["keyword"] = "keyword"
|
||||||
|
|
||||||
|
|
||||||
|
class GraphMemoryBankConfig(_MemoryBankConfigCommon):
|
||||||
|
type: Literal["graph"] = "graph"
|
||||||
|
entities: List[str] # what entities to focus on
|
||||||
|
|
||||||
|
|
||||||
|
MemoryBankConfig = Annotated[
|
||||||
|
Union[
|
||||||
|
VectorMemoryBankConfig,
|
||||||
|
KeyValueMemoryBankConfig,
|
||||||
|
KeywordMemoryBankConfig,
|
||||||
|
GraphMemoryBankConfig,
|
||||||
|
],
|
||||||
|
Field(discriminator="type"),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
class MemoryQueryGenerator(Enum):
|
||||||
|
default = "default"
|
||||||
|
llm = "llm"
|
||||||
|
custom = "custom"
|
||||||
|
|
||||||
|
|
||||||
|
class DefaultMemoryQueryGeneratorConfig(BaseModel):
|
||||||
|
type: Literal[MemoryQueryGenerator.default.value] = (
|
||||||
|
MemoryQueryGenerator.default.value
|
||||||
|
)
|
||||||
|
sep: str = " "
|
||||||
|
|
||||||
|
|
||||||
|
class LLMMemoryQueryGeneratorConfig(BaseModel):
|
||||||
|
type: Literal[MemoryQueryGenerator.llm.value] = MemoryQueryGenerator.llm.value
|
||||||
|
model: str
|
||||||
|
template: str
|
||||||
|
|
||||||
|
|
||||||
|
class CustomMemoryQueryGeneratorConfig(BaseModel):
|
||||||
|
type: Literal[MemoryQueryGenerator.custom.value] = MemoryQueryGenerator.custom.value
|
||||||
|
|
||||||
|
|
||||||
|
MemoryQueryGeneratorConfig = Annotated[
|
||||||
|
Union[
|
||||||
|
DefaultMemoryQueryGeneratorConfig,
|
||||||
|
LLMMemoryQueryGeneratorConfig,
|
||||||
|
CustomMemoryQueryGeneratorConfig,
|
||||||
|
],
|
||||||
|
Field(discriminator="type"),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
class MemoryToolConfig(BaseModel):
|
||||||
|
memory_bank_configs: List[MemoryBankConfig] = Field(default_factory=list)
|
||||||
|
|
||||||
|
|
||||||
|
class MemoryToolRuntimeConfig(BaseModel):
|
||||||
|
# This config defines how a query is generated using the messages
|
||||||
|
# for memory bank retrieval.
|
||||||
|
query_generator_config: MemoryQueryGeneratorConfig = Field(
|
||||||
|
default=DefaultMemoryQueryGeneratorConfig()
|
||||||
|
)
|
||||||
|
max_tokens_in_context: int = 4096
|
||||||
|
max_chunks: int = 5
|
|
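For context, a sketch of constructing the runtime config above with an LLM-based query generator. Class names come from the config module in this commit; the model id, template, and module path are illustrative placeholders.

# Hedged construction example; module path assumed.
from llama_stack.providers.inline.tool_runtime.memory.config import (
    LLMMemoryQueryGeneratorConfig,
    MemoryToolRuntimeConfig,
)

config = MemoryToolRuntimeConfig(
    query_generator_config=LLMMemoryQueryGeneratorConfig(
        model="Llama3.2-3B-Instruct",
        template="Summarize the user's last question as a search query: {{ messages }}",
    ),
    max_tokens_in_context=4096,
    max_chunks=5,
)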
@ -4,25 +4,29 @@
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
|
||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
from jinja2 import Template
|
from jinja2 import Template
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
from llama_stack.apis.agents import (
|
from llama_stack.apis.common.content_types import InterleavedContent
|
||||||
|
from llama_stack.apis.inference import UserMessage
|
||||||
|
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||||
|
interleaved_content_as_str,
|
||||||
|
)
|
||||||
|
|
||||||
|
from .config import (
|
||||||
DefaultMemoryQueryGeneratorConfig,
|
DefaultMemoryQueryGeneratorConfig,
|
||||||
LLMMemoryQueryGeneratorConfig,
|
LLMMemoryQueryGeneratorConfig,
|
||||||
MemoryQueryGenerator,
|
MemoryQueryGenerator,
|
||||||
MemoryQueryGeneratorConfig,
|
MemoryQueryGeneratorConfig,
|
||||||
)
|
)
|
||||||
from llama_stack.apis.inference import Message, UserMessage
|
|
||||||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
|
||||||
interleaved_content_as_str,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
async def generate_rag_query(
|
async def generate_rag_query(
|
||||||
config: MemoryQueryGeneratorConfig,
|
config: MemoryQueryGeneratorConfig,
|
||||||
messages: List[Message],
|
messages: List[InterleavedContent],
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
|
@ -40,21 +44,26 @@ async def generate_rag_query(
|
||||||
|
|
||||||
async def default_rag_query_generator(
|
async def default_rag_query_generator(
|
||||||
config: DefaultMemoryQueryGeneratorConfig,
|
config: DefaultMemoryQueryGeneratorConfig,
|
||||||
messages: List[Message],
|
messages: List[InterleavedContent],
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
return config.sep.join(interleaved_content_as_str(m.content) for m in messages)
|
return config.sep.join(interleaved_content_as_str(m) for m in messages)
|
||||||
|
|
||||||
|
|
||||||
async def llm_rag_query_generator(
|
async def llm_rag_query_generator(
|
||||||
config: LLMMemoryQueryGeneratorConfig,
|
config: LLMMemoryQueryGeneratorConfig,
|
||||||
messages: List[Message],
|
messages: List[InterleavedContent],
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
assert "inference_api" in kwargs, "LLMRAGQueryGenerator needs inference_api"
|
assert "inference_api" in kwargs, "LLMRAGQueryGenerator needs inference_api"
|
||||||
inference_api = kwargs["inference_api"]
|
inference_api = kwargs["inference_api"]
|
||||||
|
|
||||||
m_dict = {"messages": [m.model_dump() for m in messages]}
|
m_dict = {
|
||||||
|
"messages": [
|
||||||
|
message.model_dump() if isinstance(message, BaseModel) else message
|
||||||
|
for message in messages
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
template = Template(config.template)
|
template = Template(config.template)
|
||||||
content = template.render(m_dict)
|
content = template.render(m_dict)
|
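For context, a sketch of the new call shape shown above: the query generator now receives raw InterleavedContent items (plain strings here) rather than Message objects. Module paths are assumed from this commit.

# Hedged sketch; assumes interleaved_content_as_str accepts plain strings.
import asyncio

from llama_stack.providers.inline.tool_runtime.memory.config import (
    DefaultMemoryQueryGeneratorConfig,
)
from llama_stack.providers.inline.tool_runtime.memory.context_retriever import (
    generate_rag_query,
)

query = asyncio.run(
    generate_rag_query(
        DefaultMemoryQueryGeneratorConfig(sep=" "),
        ["What is the capital of France?", "And of Spain?"],
    )
)
print(query)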
llama_stack/providers/inline/tool_runtime/memory/memory.py (new file, 146 lines)
|
@ -0,0 +1,146 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
import secrets
|
||||||
|
import string
|
||||||
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
|
from llama_stack.apis.common.content_types import URL
|
||||||
|
from llama_stack.apis.inference import Inference, InterleavedContent
|
||||||
|
from llama_stack.apis.memory import Memory, QueryDocumentsResponse
|
||||||
|
from llama_stack.apis.memory_banks import MemoryBanks
|
||||||
|
from llama_stack.apis.tools import (
|
||||||
|
ToolDef,
|
||||||
|
ToolInvocationResult,
|
||||||
|
ToolParameter,
|
||||||
|
ToolRuntime,
|
||||||
|
)
|
||||||
|
from llama_stack.providers.datatypes import ToolsProtocolPrivate
|
||||||
|
from llama_stack.providers.utils.memory.vector_store import concat_interleaved_content
|
||||||
|
|
||||||
|
from .config import MemoryToolConfig, MemoryToolRuntimeConfig
|
||||||
|
from .context_retriever import generate_rag_query
|
||||||
|
|
||||||
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def make_random_string(length: int = 8):
|
||||||
|
return "".join(
|
||||||
|
secrets.choice(string.ascii_letters + string.digits) for _ in range(length)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
config: MemoryToolRuntimeConfig,
|
||||||
|
memory_api: Memory,
|
||||||
|
memory_banks_api: MemoryBanks,
|
||||||
|
inference_api: Inference,
|
||||||
|
):
|
||||||
|
self.config = config
|
||||||
|
self.memory_api = memory_api
|
||||||
|
self.memory_banks_api = memory_banks_api
|
||||||
|
self.inference_api = inference_api
|
||||||
|
|
||||||
|
async def initialize(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def list_runtime_tools(
|
||||||
|
self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
|
||||||
|
) -> List[ToolDef]:
|
||||||
|
return [
|
||||||
|
ToolDef(
|
||||||
|
name="query_memory",
|
||||||
|
description="Retrieve context from memory",
|
||||||
|
parameters=[
|
||||||
|
ToolParameter(
|
||||||
|
name="messages",
|
||||||
|
description="The input messages to search for",
|
||||||
|
parameter_type="array",
|
||||||
|
),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
|
async def _retrieve_context(
|
||||||
|
self, input_messages: List[InterleavedContent], bank_ids: List[str]
|
||||||
|
) -> Optional[List[InterleavedContent]]:
|
||||||
|
if not bank_ids:
|
||||||
|
return None
|
||||||
|
query = await generate_rag_query(
|
||||||
|
self.config.query_generator_config,
|
||||||
|
input_messages,
|
||||||
|
inference_api=self.inference_api,
|
||||||
|
)
|
||||||
|
tasks = [
|
||||||
|
self.memory_api.query_documents(
|
||||||
|
bank_id=bank_id,
|
||||||
|
query=query,
|
||||||
|
params={
|
||||||
|
"max_chunks": self.config.max_chunks,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
for bank_id in bank_ids
|
||||||
|
]
|
||||||
|
results: List[QueryDocumentsResponse] = await asyncio.gather(*tasks)
|
||||||
|
chunks = [c for r in results for c in r.chunks]
|
||||||
|
scores = [s for r in results for s in r.scores]
|
||||||
|
|
||||||
|
if not chunks:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# sort by score
|
||||||
|
chunks, scores = zip(
|
||||||
|
*sorted(zip(chunks, scores), key=lambda x: x[1], reverse=True)
|
||||||
|
)
|
||||||
|
|
||||||
|
tokens = 0
|
||||||
|
picked = []
|
||||||
|
for c in chunks[: self.config.max_chunks]:
|
||||||
|
tokens += c.token_count
|
||||||
|
if tokens > self.config.max_tokens_in_context:
|
||||||
|
log.error(
|
||||||
|
f"Using {len(picked)} chunks; reached max tokens in context: {tokens}",
|
||||||
|
)
|
||||||
|
break
|
||||||
|
picked.append(f"id:{c.document_id}; content:{c.content}")
|
||||||
|
|
||||||
|
return [
|
||||||
|
"Here are the retrieved documents for relevant context:\n=== START-RETRIEVED-CONTEXT ===\n",
|
||||||
|
*picked,
|
||||||
|
"\n=== END-RETRIEVED-CONTEXT ===\n",
|
||||||
|
]
|
||||||
|
|
||||||
|
async def invoke_tool(
|
||||||
|
self, tool_name: str, args: Dict[str, Any]
|
||||||
|
) -> ToolInvocationResult:
|
||||||
|
tool = await self.tool_store.get_tool(tool_name)
|
||||||
|
tool_group = await self.tool_store.get_tool_group(tool.toolgroup_id)
|
||||||
|
final_args = tool_group.args or {}
|
||||||
|
final_args.update(args)
|
||||||
|
config = MemoryToolConfig()
|
||||||
|
if tool.metadata and tool.metadata.get("config") is not None:
|
||||||
|
config = MemoryToolConfig(**tool.metadata["config"])
|
||||||
|
if "memory_bank_ids" in final_args:
|
||||||
|
bank_ids = final_args["memory_bank_ids"]
|
||||||
|
else:
|
||||||
|
bank_ids = [
|
||||||
|
bank_config.bank_id for bank_config in config.memory_bank_configs
|
||||||
|
]
|
||||||
|
if "messages" not in final_args:
|
||||||
|
raise ValueError("messages are required")
|
||||||
|
context = await self._retrieve_context(
|
||||||
|
final_args["messages"],
|
||||||
|
bank_ids,
|
||||||
|
)
|
||||||
|
if context is None:
|
||||||
|
context = []
|
||||||
|
return ToolInvocationResult(
|
||||||
|
content=concat_interleaved_content(context), error_code=0
|
||||||
|
)
|
|
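Reviewer note: the context-budgeting step in _retrieve_context above (sort retrieved chunks by score, then keep adding chunks until the token budget is exceeded) is the core of the new memory tool. The following standalone Python sketch reproduces that selection logic with fabricated chunk data so it can be run in isolation; it is an illustration, not code from this diff.

# Standalone illustration of the chunk-selection loop in _retrieve_context.
# The chunk tuples (document_id, content, token_count) and scores are made up.
chunks = [("doc-1", "alpha", 120), ("doc-2", "beta", 300), ("doc-3", "gamma", 250)]
scores = [0.2, 0.9, 0.5]

max_tokens_in_context = 400  # stands in for config.max_tokens_in_context

# Sort chunks by score, highest first (same zip/sort pattern as the provider).
ranked = [c for c, _ in sorted(zip(chunks, scores), key=lambda x: x[1], reverse=True)]

tokens = 0
picked = []
for doc_id, content, token_count in ranked:
    tokens += token_count
    if tokens > max_tokens_in_context:
        # budget exhausted; the chunk that overflows is dropped, as in the provider
        break
    picked.append(f"id:{doc_id}; content:{content}")

print(picked)  # ['id:doc-2; content:beta'] with this budget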
@@ -35,6 +35,8 @@ def available_providers() -> List[ProviderSpec]:
                 Api.safety,
                 Api.memory,
                 Api.memory_banks,
+                Api.tool_runtime,
+                Api.tool_groups,
             ],
         ),
         remote_provider_spec(
@@ -19,11 +19,58 @@ def available_providers() -> List[ProviderSpec]:
     return [
         InlineProviderSpec(
             api=Api.tool_runtime,
-            provider_type="inline::brave-search",
+            provider_type="inline::memory-runtime",
             pip_packages=[],
-            module="llama_stack.providers.inline.tool_runtime.brave_search",
-            config_class="llama_stack.providers.inline.tool_runtime.brave_search.config.BraveSearchToolConfig",
-            provider_data_validator="llama_stack.providers.inline.tool_runtime.brave_search.BraveSearchToolProviderDataValidator",
+            module="llama_stack.providers.inline.tool_runtime.memory",
+            config_class="llama_stack.providers.inline.tool_runtime.memory.config.MemoryToolRuntimeConfig",
+            api_dependencies=[Api.memory, Api.memory_banks, Api.inference],
+        ),
+        InlineProviderSpec(
+            api=Api.tool_runtime,
+            provider_type="inline::code-interpreter",
+            pip_packages=[],
+            module="llama_stack.providers.inline.tool_runtime.code_interpreter",
+            config_class="llama_stack.providers.inline.tool_runtime.code_interpreter.config.CodeInterpreterToolConfig",
+        ),
+        remote_provider_spec(
+            api=Api.tool_runtime,
+            adapter=AdapterSpec(
+                adapter_type="brave-search",
+                module="llama_stack.providers.remote.tool_runtime.brave_search",
+                config_class="llama_stack.providers.remote.tool_runtime.brave_search.config.BraveSearchToolConfig",
+                pip_packages=["requests"],
+                provider_data_validator="llama_stack.providers.remote.tool_runtime.brave_search.BraveSearchToolProviderDataValidator",
+            ),
+        ),
+        remote_provider_spec(
+            api=Api.tool_runtime,
+            adapter=AdapterSpec(
+                adapter_type="bing-search",
+                module="llama_stack.providers.remote.tool_runtime.bing_search",
+                config_class="llama_stack.providers.remote.tool_runtime.bing_search.config.BingSearchToolConfig",
+                pip_packages=["requests"],
+                provider_data_validator="llama_stack.providers.remote.tool_runtime.bing_search.BingSearchToolProviderDataValidator",
+            ),
+        ),
+        remote_provider_spec(
+            api=Api.tool_runtime,
+            adapter=AdapterSpec(
+                adapter_type="tavily-search",
+                module="llama_stack.providers.remote.tool_runtime.tavily_search",
+                config_class="llama_stack.providers.remote.tool_runtime.tavily_search.config.TavilySearchToolConfig",
+                pip_packages=["requests"],
+                provider_data_validator="llama_stack.providers.remote.tool_runtime.tavily_search.TavilySearchToolProviderDataValidator",
+            ),
+        ),
+        remote_provider_spec(
+            api=Api.tool_runtime,
+            adapter=AdapterSpec(
+                adapter_type="wolfram-alpha",
+                module="llama_stack.providers.remote.tool_runtime.wolfram_alpha",
+                config_class="llama_stack.providers.remote.tool_runtime.wolfram_alpha.config.WolframAlphaToolConfig",
+                pip_packages=["requests"],
+                provider_data_validator="llama_stack.providers.remote.tool_runtime.wolfram_alpha.WolframAlphaToolProviderDataValidator",
+            ),
         ),
         remote_provider_spec(
             api=Api.tool_runtime,
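Reviewer note: these provider specs store module and config_class as dotted-path strings that the stack resolves lazily. The standalone sketch below illustrates that indirection with a stdlib class standing in for a tool config class; it is a pattern illustration only, not llama-stack's actual loader.

# Illustration of dotted-path resolution as used conceptually by provider specs.
# "collections.OrderedDict" is just a stand-in for a real config_class string.
import importlib


def resolve(dotted_path: str):
    # Split "pkg.module.Attr" into the module part and the attribute part,
    # import the module, and pull the attribute off it.
    module_path, _, attr = dotted_path.rpartition(".")
    return getattr(importlib.import_module(module_path), attr)


config_cls = resolve("collections.OrderedDict")
print(config_cls())  # prints: OrderedDict()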
@@ -10,7 +10,6 @@ from typing import AsyncGenerator, AsyncIterator, Dict, List, Optional, Union
 from botocore.client import BaseClient
 from llama_models.datatypes import CoreModelId
 from llama_models.llama3.api.chat_format import ChatFormat
-
 from llama_models.llama3.api.tokenizer import Tokenizer

 from llama_stack.apis.common.content_types import InterleavedContent

@@ -30,7 +29,6 @@ from llama_stack.apis.inference import (
 )
 from llama_stack.providers.remote.inference.bedrock.config import BedrockConfig
 from llama_stack.providers.utils.bedrock.client import create_bedrock_client
-
 from llama_stack.providers.utils.inference.model_registry import (
     build_model_alias,
     ModelRegistryHelper,

@@ -47,7 +45,6 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
     interleaved_content_as_str,
 )

-
 MODEL_ALIASES = [
     build_model_alias(
         "meta.llama3-1-8b-instruct-v1:0",

@@ -101,7 +98,7 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
         response_format: Optional[ResponseFormat] = None,
         tools: Optional[List[ToolDefinition]] = None,
         tool_choice: Optional[ToolChoice] = ToolChoice.auto,
-        tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
+        tool_prompt_format: Optional[ToolPromptFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> Union[
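Reviewer note: the recurring change across these inference adapters moves the tool_prompt_format default from ToolPromptFormat.json to None, echoing the comment removed from the Groq and NVIDIA adapters ("API default is ToolPromptFormat.json, we default to None to detect user input"). The standalone sketch below shows why a None default lets an adapter distinguish an explicit caller choice from the library fallback; the enum here is a stand-in, not the real llama_models type.

# Standalone illustration of the "None default to detect user input" pattern.
from enum import Enum
from typing import Optional


class ToolPromptFormat(Enum):  # stand-in for the real enum in llama_models
    json = "json"
    function_tag = "function_tag"


def choose_format(user_value: Optional[ToolPromptFormat]) -> ToolPromptFormat:
    if user_value is not None:
        return user_value            # caller opted in explicitly
    return ToolPromptFormat.json     # adapter-chosen fallback when unset


print(choose_format(None))                             # ToolPromptFormat.json
print(choose_format(ToolPromptFormat.function_tag))    # ToolPromptFormat.function_tag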
@@ -7,11 +7,8 @@
 from typing import AsyncGenerator, List, Optional, Union

 from cerebras.cloud.sdk import AsyncCerebras
-
 from llama_models.datatypes import CoreModelId
-
 from llama_models.llama3.api.chat_format import ChatFormat
-
 from llama_models.llama3.api.tokenizer import Tokenizer

 from llama_stack.apis.common.content_types import InterleavedContent

@@ -29,7 +26,6 @@ from llama_stack.apis.inference import (
     ToolDefinition,
     ToolPromptFormat,
 )
-
 from llama_stack.providers.utils.inference.model_registry import (
     build_model_alias,
     ModelRegistryHelper,

@@ -48,7 +44,6 @@ from llama_stack.providers.utils.inference.prompt_adapter import (

 from .config import CerebrasImplConfig

-
 model_aliases = [
     build_model_alias(
         "llama3.1-8b",

@@ -130,7 +125,7 @@ class CerebrasInferenceAdapter(ModelRegistryHelper, Inference):
         sampling_params: Optional[SamplingParams] = SamplingParams(),
         tools: Optional[List[ToolDefinition]] = None,
         tool_choice: Optional[ToolChoice] = ToolChoice.auto,
-        tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
+        tool_prompt_format: Optional[ToolPromptFormat] = None,
         response_format: Optional[ResponseFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
@@ -7,11 +7,8 @@
 from typing import AsyncGenerator, List, Optional

 from llama_models.datatypes import CoreModelId
-
 from llama_models.llama3.api.chat_format import ChatFormat
-
 from llama_models.llama3.api.tokenizer import Tokenizer
-
 from openai import OpenAI

 from llama_stack.apis.common.content_types import InterleavedContent

@@ -28,7 +25,6 @@ from llama_stack.apis.inference import (
     ToolDefinition,
     ToolPromptFormat,
 )
-
 from llama_stack.providers.utils.inference.model_registry import (
     build_model_alias,
     ModelRegistryHelper,

@@ -44,7 +40,6 @@ from llama_stack.providers.utils.inference.prompt_adapter import (

 from .config import DatabricksImplConfig

-
 model_aliases = [
     build_model_alias(
         "databricks-meta-llama-3-1-70b-instruct",

@@ -91,7 +86,7 @@ class DatabricksInferenceAdapter(ModelRegistryHelper, Inference):
         response_format: Optional[ResponseFormat] = None,
         tools: Optional[List[ToolDefinition]] = None,
         tool_choice: Optional[ToolChoice] = ToolChoice.auto,
-        tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
+        tool_prompt_format: Optional[ToolPromptFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> AsyncGenerator:
@@ -22,7 +22,7 @@ class FireworksImplConfig(BaseModel):
     )

     @classmethod
-    def sample_run_config(cls, __distro_dir__: str) -> Dict[str, Any]:
+    def sample_run_config(cls, **kwargs) -> Dict[str, Any]:
         return {
             "url": "https://api.fireworks.ai/inference/v1",
             "api_key": "${env.FIREWORKS_API_KEY}",
@@ -8,7 +8,6 @@ from typing import AsyncGenerator, List, Optional, Union

 from fireworks.client import Fireworks
 from llama_models.datatypes import CoreModelId
-
 from llama_models.llama3.api.chat_format import ChatFormat
 from llama_models.llama3.api.tokenizer import Tokenizer

@@ -52,46 +51,45 @@ from llama_stack.providers.utils.inference.prompt_adapter import (

 from .config import FireworksImplConfig

-
 MODEL_ALIASES = [
     build_model_alias(
-        "fireworks/llama-v3p1-8b-instruct",
+        "accounts/fireworks/models/llama-v3p1-8b-instruct",
         CoreModelId.llama3_1_8b_instruct.value,
     ),
     build_model_alias(
-        "fireworks/llama-v3p1-70b-instruct",
+        "accounts/fireworks/models/llama-v3p1-70b-instruct",
         CoreModelId.llama3_1_70b_instruct.value,
     ),
     build_model_alias(
-        "fireworks/llama-v3p1-405b-instruct",
+        "accounts/fireworks/models/llama-v3p1-405b-instruct",
         CoreModelId.llama3_1_405b_instruct.value,
     ),
     build_model_alias(
-        "fireworks/llama-v3p2-1b-instruct",
+        "accounts/fireworks/models/llama-v3p2-1b-instruct",
         CoreModelId.llama3_2_1b_instruct.value,
     ),
     build_model_alias(
-        "fireworks/llama-v3p2-3b-instruct",
+        "accounts/fireworks/models/llama-v3p2-3b-instruct",
         CoreModelId.llama3_2_3b_instruct.value,
     ),
     build_model_alias(
-        "fireworks/llama-v3p2-11b-vision-instruct",
+        "accounts/fireworks/models/llama-v3p2-11b-vision-instruct",
         CoreModelId.llama3_2_11b_vision_instruct.value,
     ),
     build_model_alias(
-        "fireworks/llama-v3p2-90b-vision-instruct",
+        "accounts/fireworks/models/llama-v3p2-90b-vision-instruct",
         CoreModelId.llama3_2_90b_vision_instruct.value,
     ),
     build_model_alias(
-        "fireworks/llama-v3p3-70b-instruct",
+        "accounts/fireworks/models/llama-v3p3-70b-instruct",
         CoreModelId.llama3_3_70b_instruct.value,
     ),
     build_model_alias(
-        "fireworks/llama-guard-3-8b",
+        "accounts/fireworks/models/llama-guard-3-8b",
         CoreModelId.llama_guard_3_8b.value,
     ),
     build_model_alias(
-        "fireworks/llama-guard-3-11b-vision",
+        "accounts/fireworks/models/llama-guard-3-11b-vision",
         CoreModelId.llama_guard_3_11b_vision.value,
     ),
 ]

@@ -118,7 +116,7 @@ class FireworksInferenceAdapter(
         provider_data = self.get_request_provider_data()
         if provider_data is None or not provider_data.fireworks_api_key:
             raise ValueError(
-                'Pass Fireworks API Key in the header X-LlamaStack-ProviderData as { "fireworks_api_key": <your api key>}'
+                'Pass Fireworks API Key in the header X-LlamaStack-Provider-Data as { "fireworks_api_key": <your api key>}'
             )
         return provider_data.fireworks_api_key

@@ -198,7 +196,7 @@ class FireworksInferenceAdapter(
         sampling_params: Optional[SamplingParams] = SamplingParams(),
         tools: Optional[List[ToolDefinition]] = None,
         tool_choice: Optional[ToolChoice] = ToolChoice.auto,
-        tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
+        tool_prompt_format: Optional[ToolPromptFormat] = None,
         response_format: Optional[ResponseFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
@@ -7,6 +7,7 @@
 import warnings
 from typing import AsyncIterator, List, Optional, Union

+import groq
 from groq import Groq
 from llama_models.datatypes import SamplingParams
 from llama_models.llama3.api.datatypes import ToolDefinition, ToolPromptFormat

@@ -33,6 +34,7 @@ from llama_stack.providers.utils.inference.model_registry import (
     build_model_alias_with_just_provider_model_id,
     ModelRegistryHelper,
 )

 from .groq_utils import (
     convert_chat_completion_request,
     convert_chat_completion_response,

@@ -94,9 +96,7 @@ class GroqInferenceAdapter(Inference, ModelRegistryHelper, NeedsRequestProviderData):
         response_format: Optional[ResponseFormat] = None,
         tools: Optional[List[ToolDefinition]] = None,
         tool_choice: Optional[ToolChoice] = ToolChoice.auto,
-        tool_prompt_format: Optional[
-            ToolPromptFormat
-        ] = None,  # API default is ToolPromptFormat.json, we default to None to detect user input
+        tool_prompt_format: Optional[ToolPromptFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> Union[

@@ -124,7 +124,16 @@ class GroqInferenceAdapter(Inference, ModelRegistryHelper, NeedsRequestProviderData):
             )
         )

-        response = self._get_client().chat.completions.create(**request)
+        try:
+            response = self._get_client().chat.completions.create(**request)
+        except groq.BadRequestError as e:
+            if e.body.get("error", {}).get("code") == "tool_use_failed":
+                # For smaller models, Groq may fail to call a tool even when the request is well formed
+                raise ValueError(
+                    "Groq failed to call a tool", e.body.get("error", {})
+                ) from e
+            else:
+                raise e

         if stream:
             return convert_chat_completion_response_stream(response)

@@ -145,6 +154,6 @@ class GroqInferenceAdapter(Inference, ModelRegistryHelper, NeedsRequestProviderData):
         provider_data = self.get_request_provider_data()
         if provider_data is None or not provider_data.groq_api_key:
             raise ValueError(
-                'Pass Groq API Key in the header X-LlamaStack-ProviderData as { "groq_api_key": "<your api key>" }'
+                'Pass Groq API Key in the header X-LlamaStack-Provider-Data as { "groq_api_key": "<your api key>" }'
             )
         return Groq(api_key=provider_data.groq_api_key)
@@ -4,6 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

+import json
 import warnings
 from typing import AsyncGenerator, Literal

@@ -14,14 +15,20 @@ from groq.types.chat.chat_completion_assistant_message_param import (
 )
 from groq.types.chat.chat_completion_chunk import ChatCompletionChunk
 from groq.types.chat.chat_completion_message_param import ChatCompletionMessageParam
+from groq.types.chat.chat_completion_message_tool_call import (
+    ChatCompletionMessageToolCall,
+)
 from groq.types.chat.chat_completion_system_message_param import (
     ChatCompletionSystemMessageParam,
 )
+from groq.types.chat.chat_completion_tool_param import ChatCompletionToolParam
 from groq.types.chat.chat_completion_user_message_param import (
     ChatCompletionUserMessageParam,
 )

 from groq.types.chat.completion_create_params import CompletionCreateParams
+from groq.types.shared.function_definition import FunctionDefinition

+from llama_models.llama3.api.datatypes import ToolParamDefinition

 from llama_stack.apis.inference import (
     ChatCompletionRequest,

@@ -32,6 +39,11 @@ from llama_stack.apis.inference import (
     CompletionMessage,
     Message,
     StopReason,
+    ToolCall,
+    ToolCallDelta,
+    ToolCallParseStatus,
+    ToolDefinition,
+    ToolPromptFormat,
 )

@@ -59,8 +71,8 @@ def convert_chat_completion_request(
     # so we exclude it for now
     warnings.warn("repetition_penalty is not supported")

-    if request.tools:
-        warnings.warn("tools are not supported yet")
+    if request.tool_prompt_format != ToolPromptFormat.json:
+        warnings.warn("tool_prompt_format is not used by Groq. Ignoring.")

     return CompletionCreateParams(
         model=request.model,

@@ -71,6 +83,8 @@ def convert_chat_completion_request(
         max_tokens=request.sampling_params.max_tokens or None,
         temperature=request.sampling_params.temperature,
         top_p=request.sampling_params.top_p,
+        tools=[_convert_groq_tool_definition(tool) for tool in request.tools or []],
+        tool_choice=request.tool_choice.value if request.tool_choice else None,
     )

@@ -87,17 +101,64 @@ def _convert_message(message: Message) -> ChatCompletionMessageParam:
         raise ValueError(f"Invalid message role: {message.role}")


+def _convert_groq_tool_definition(tool_definition: ToolDefinition) -> dict:
+    # Groq requires a description for function tools
+    if tool_definition.description is None:
+        raise AssertionError("tool_definition.description is required")
+
+    tool_parameters = tool_definition.parameters or {}
+    return ChatCompletionToolParam(
+        type="function",
+        function=FunctionDefinition(
+            name=tool_definition.tool_name,
+            description=tool_definition.description,
+            parameters={
+                key: _convert_groq_tool_parameter(param)
+                for key, param in tool_parameters.items()
+            },
+        ),
+    )
+
+
+def _convert_groq_tool_parameter(tool_parameter: ToolParamDefinition) -> dict:
+    param = {
+        "type": tool_parameter.param_type,
+    }
+    if tool_parameter.description is not None:
+        param["description"] = tool_parameter.description
+    if tool_parameter.required is not None:
+        param["required"] = tool_parameter.required
+    if tool_parameter.default is not None:
+        param["default"] = tool_parameter.default
+    return param
+
+
 def convert_chat_completion_response(
     response: ChatCompletion,
 ) -> ChatCompletionResponse:
     # groq only supports n=1 at time of writing, so there is only one choice
     choice = response.choices[0]
-    return ChatCompletionResponse(
-        completion_message=CompletionMessage(
-            content=choice.message.content,
-            stop_reason=_map_finish_reason_to_stop_reason(choice.finish_reason),
-        ),
-    )
+    if choice.finish_reason == "tool_calls":
+        tool_calls = [
+            _convert_groq_tool_call(tool_call)
+            for tool_call in choice.message.tool_calls
+        ]
+        return ChatCompletionResponse(
+            completion_message=CompletionMessage(
+                tool_calls=tool_calls,
+                stop_reason=StopReason.end_of_message,
+                # Content is not optional
+                content="",
+            ),
+            logprobs=None,
+        )
+    else:
+        return ChatCompletionResponse(
+            completion_message=CompletionMessage(
+                content=choice.message.content,
+                stop_reason=_map_finish_reason_to_stop_reason(choice.finish_reason),
+            ),
+        )


 def _map_finish_reason_to_stop_reason(

@@ -116,7 +177,7 @@ def _map_finish_reason_to_stop_reason(
     elif finish_reason == "length":
         return StopReason.out_of_tokens
     elif finish_reason == "tool_calls":
-        raise NotImplementedError("tool_calls is not supported yet")
+        return StopReason.end_of_message
     else:
         raise ValueError(f"Invalid finish reason: {finish_reason}")

@@ -129,25 +190,50 @@ async def convert_chat_completion_response_stream(
     for chunk in stream:
         choice = chunk.choices[0]

-        # We assume there's only one finish_reason for the entire stream.
-        # We collect the last finish_reason
         if choice.finish_reason:
-            stop_reason = _map_finish_reason_to_stop_reason(choice.finish_reason)
-
-        yield ChatCompletionResponseStreamChunk(
-            event=ChatCompletionResponseEvent(
-                event_type=event_type,
-                delta=choice.delta.content or "",
-                logprobs=None,
+            yield ChatCompletionResponseStreamChunk(
+                event=ChatCompletionResponseEvent(
+                    event_type=ChatCompletionResponseEventType.complete,
+                    delta=choice.delta.content or "",
+                    logprobs=None,
+                    stop_reason=_map_finish_reason_to_stop_reason(choice.finish_reason),
+                )
+            )
+        elif choice.delta.tool_calls:
+            # We assume there is only one tool call per chunk, but emit a warning in case we're wrong
+            if len(choice.delta.tool_calls) > 1:
+                warnings.warn(
+                    "Groq returned multiple tool calls in one chunk. Using the first one, ignoring the rest."
+                )
+
+            # We assume Groq produces fully formed tool calls for each chunk
+            tool_call = _convert_groq_tool_call(choice.delta.tool_calls[0])
+            yield ChatCompletionResponseStreamChunk(
+                event=ChatCompletionResponseEvent(
+                    event_type=event_type,
+                    delta=ToolCallDelta(
+                        content=tool_call,
+                        parse_status=ToolCallParseStatus.success,
+                    ),
+                )
+            )
+        else:
+            yield ChatCompletionResponseStreamChunk(
+                event=ChatCompletionResponseEvent(
+                    event_type=event_type,
+                    delta=choice.delta.content or "",
+                    logprobs=None,
+                )
             )
-        )
         event_type = ChatCompletionResponseEventType.progress

-    yield ChatCompletionResponseStreamChunk(
-        event=ChatCompletionResponseEvent(
-            event_type=ChatCompletionResponseEventType.complete,
-            delta="",
-            logprobs=None,
-            stop_reason=stop_reason,
-        )
-    )
+
+def _convert_groq_tool_call(tool_call: ChatCompletionMessageToolCall) -> ToolCall:
+    return ToolCall(
+        call_id=tool_call.id,
+        tool_name=tool_call.function.name,
+        # Note that Groq may return a string that is not valid JSON here
+        # So this may raise a 500 error. Going to leave this as is to see
+        # how big of an issue this is and what we can do about it.
+        arguments=json.loads(tool_call.function.arguments),
+    )
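Reviewer note: the new _convert_groq_tool_definition / _convert_groq_tool_parameter helpers above target Groq's OpenAI-compatible "function tool" payload. The standalone sketch below builds the same shape with plain dicts so it can be run without the groq SDK; the tool name and fields are examples, not values from this diff.

# Illustration of the function-tool payload shape the Groq conversion targets.
import json

tool_param = {
    "type": "function",
    "function": {
        "name": "get_weather",  # example tool name for illustration
        "description": "Get the weather for a city",
        "parameters": {
            "city": {
                "type": "string",
                "description": "Name of the city",
                "required": True,
            },
        },
    },
}

print(json.dumps(tool_param, indent=2))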
@@ -175,9 +175,7 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
         response_format: Optional[ResponseFormat] = None,
         tools: Optional[List[ToolDefinition]] = None,
         tool_choice: Optional[ToolChoice] = ToolChoice.auto,
-        tool_prompt_format: Optional[
-            ToolPromptFormat
-        ] = None,  # API default is ToolPromptFormat.json, we default to None to detect user input
+        tool_prompt_format: Optional[ToolPromptFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> Union[
@@ -144,7 +144,7 @@ def _convert_message(message: Message | Dict) -> OpenAIChatCompletionMessage:
         message = UserMessage(**message)
     elif message["role"] == "assistant":
         message = CompletionMessage(**message)
-    elif message["role"] == "ipython":
+    elif message["role"] == "tool":
         message = ToolResponseMessage(**message)
     elif message["role"] == "system":
         message = SystemMessage(**message)
@@ -9,7 +9,6 @@ from typing import AsyncGenerator, List, Optional, Union

 import httpx
 from llama_models.datatypes import CoreModelId
-
 from llama_models.llama3.api.chat_format import ChatFormat
 from llama_models.llama3.api.tokenizer import Tokenizer
 from ollama import AsyncClient

@@ -35,7 +34,6 @@ from llama_stack.apis.inference import (
 )
 from llama_stack.apis.models import Model, ModelType
 from llama_stack.providers.datatypes import ModelsProtocolPrivate
-
 from llama_stack.providers.utils.inference.model_registry import (
     build_model_alias,
     build_model_alias_with_just_provider_model_id,

@@ -222,7 +220,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
         response_format: Optional[ResponseFormat] = None,
         tools: Optional[List[ToolDefinition]] = None,
         tool_choice: Optional[ToolChoice] = ToolChoice.auto,
-        tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
+        tool_prompt_format: Optional[ToolPromptFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> AsyncGenerator:
@@ -30,13 +30,11 @@ from llama_stack.apis.inference import (
     ToolPromptFormat,
 )
 from llama_stack.apis.models import Model
-
 from llama_stack.providers.datatypes import ModelsProtocolPrivate
 from llama_stack.providers.utils.inference.model_registry import (
     build_model_alias,
     ModelRegistryHelper,
 )
-
 from llama_stack.providers.utils.inference.openai_compat import (
     get_sampling_options,
     OpenAICompatCompletionChoice,

@@ -205,7 +203,7 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
         sampling_params: Optional[SamplingParams] = SamplingParams(),
         tools: Optional[List[ToolDefinition]] = None,
         tool_choice: Optional[ToolChoice] = ToolChoice.auto,
-        tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
+        tool_prompt_format: Optional[ToolPromptFormat] = None,
         response_format: Optional[ResponseFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
@@ -7,11 +7,8 @@
 from typing import AsyncGenerator, List, Optional, Union

 from llama_models.datatypes import CoreModelId
-
 from llama_models.llama3.api.chat_format import ChatFormat
-
 from llama_models.llama3.api.tokenizer import Tokenizer
-
 from together import Together

 from llama_stack.apis.common.content_types import InterleavedContent

@@ -53,7 +50,6 @@ from llama_stack.providers.utils.inference.prompt_adapter import (

 from .config import TogetherImplConfig

-
 MODEL_ALIASES = [
     build_model_alias(
         "meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo",

@@ -79,6 +75,10 @@ MODEL_ALIASES = [
         "meta-llama/Llama-3.2-90B-Vision-Instruct-Turbo",
         CoreModelId.llama3_2_90b_vision_instruct.value,
     ),
+    build_model_alias(
+        "meta-llama/Llama-3.3-70B-Instruct-Turbo",
+        CoreModelId.llama3_3_70b_instruct.value,
+    ),
     build_model_alias(
         "meta-llama/Meta-Llama-Guard-3-8B",
         CoreModelId.llama_guard_3_8b.value,

@@ -135,7 +135,7 @@ class TogetherInferenceAdapter(
         provider_data = self.get_request_provider_data()
         if provider_data is None or not provider_data.together_api_key:
             raise ValueError(
-                'Pass Together API Key in the header X-LlamaStack-ProviderData as { "together_api_key": <your api key>}'
+                'Pass Together API Key in the header X-LlamaStack-Provider-Data as { "together_api_key": <your api key>}'
             )
         together_api_key = provider_data.together_api_key
         return Together(api_key=together_api_key)

@@ -184,7 +184,7 @@ class TogetherInferenceAdapter(
         sampling_params: Optional[SamplingParams] = SamplingParams(),
         tools: Optional[List[ToolDefinition]] = None,
         tool_choice: Optional[ToolChoice] = ToolChoice.auto,
-        tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
+        tool_prompt_format: Optional[ToolPromptFormat] = None,
         response_format: Optional[ResponseFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
@@ -10,7 +10,6 @@ from typing import AsyncGenerator, List, Optional, Union
 from llama_models.llama3.api.chat_format import ChatFormat
 from llama_models.llama3.api.tokenizer import Tokenizer
 from llama_models.sku_list import all_registered_models
-
 from openai import OpenAI

 from llama_stack.apis.common.content_types import InterleavedContent

@@ -33,7 +32,6 @@ from llama_stack.apis.inference import (
 )
 from llama_stack.apis.models import Model, ModelType
 from llama_stack.providers.datatypes import ModelsProtocolPrivate
-
 from llama_stack.providers.utils.inference.model_registry import (
     build_model_alias,
     ModelRegistryHelper,

@@ -54,7 +52,6 @@ from llama_stack.providers.utils.inference.prompt_adapter import (

 from .config import VLLMInferenceAdapterConfig

-
 log = logging.getLogger(__name__)

@@ -105,7 +102,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
         response_format: Optional[ResponseFormat] = None,
         tools: Optional[List[ToolDefinition]] = None,
         tool_choice: Optional[ToolChoice] = ToolChoice.auto,
-        tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
+        tool_prompt_format: Optional[ToolPromptFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> AsyncGenerator:
5
llama_stack/providers/remote/tool_runtime/__init__.py
Normal file

@@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
@@ -0,0 +1,21 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from pydantic import BaseModel

from .bing_search import BingSearchToolRuntimeImpl
from .config import BingSearchToolConfig

__all__ = ["BingSearchToolConfig", "BingSearchToolRuntimeImpl"]


class BingSearchToolProviderDataValidator(BaseModel):
    bing_search_api_key: str


async def get_adapter_impl(config: BingSearchToolConfig, _deps):
    impl = BingSearchToolRuntimeImpl(config)
    await impl.initialize()
    return impl
@@ -0,0 +1,114 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import json
from typing import Any, Dict, List, Optional

import requests

from llama_stack.apis.common.content_types import URL
from llama_stack.apis.tools import (
    Tool,
    ToolDef,
    ToolInvocationResult,
    ToolParameter,
    ToolRuntime,
)
from llama_stack.distribution.request_headers import NeedsRequestProviderData
from llama_stack.providers.datatypes import ToolsProtocolPrivate

from .config import BingSearchToolConfig


class BingSearchToolRuntimeImpl(
    ToolsProtocolPrivate, ToolRuntime, NeedsRequestProviderData
):
    def __init__(self, config: BingSearchToolConfig):
        self.config = config
        self.url = "https://api.bing.microsoft.com/v7.0/search"

    async def initialize(self):
        pass

    async def register_tool(self, tool: Tool):
        pass

    async def unregister_tool(self, tool_id: str) -> None:
        return

    def _get_api_key(self) -> str:
        if self.config.api_key:
            return self.config.api_key

        provider_data = self.get_request_provider_data()
        if provider_data is None or not provider_data.bing_search_api_key:
            raise ValueError(
                'Pass Bing Search API Key in the header X-LlamaStack-Provider-Data as { "bing_search_api_key": <your api key>}'
            )
        return provider_data.bing_search_api_key

    async def list_runtime_tools(
        self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
    ) -> List[ToolDef]:
        return [
            ToolDef(
                name="web_search",
                description="Search the web using Bing Search API",
                parameters=[
                    ToolParameter(
                        name="query",
                        description="The query to search for",
                        parameter_type="string",
                    )
                ],
            )
        ]

    async def invoke_tool(
        self, tool_name: str, args: Dict[str, Any]
    ) -> ToolInvocationResult:
        api_key = self._get_api_key()
        headers = {
            "Ocp-Apim-Subscription-Key": api_key,
        }
        params = {
            "count": self.config.top_k,
            "textDecorations": True,
            "textFormat": "HTML",
            "q": args["query"],
        }

        response = requests.get(
            url=self.url,
            params=params,
            headers=headers,
        )
        response.raise_for_status()

        return ToolInvocationResult(
            content=json.dumps(self._clean_response(response.json()))
        )

    def _clean_response(self, search_response):
        clean_response = []
        query = search_response["queryContext"]["originalQuery"]
        if "webPages" in search_response:
            pages = search_response["webPages"]["value"]
            for p in pages:
                selected_keys = {"name", "url", "snippet"}
                clean_response.append(
                    {k: v for k, v in p.items() if k in selected_keys}
                )
        if "news" in search_response:
            clean_news = []
            news = search_response["news"]["value"]
            for n in news:
                selected_keys = {"name", "url", "description"}
                clean_news.append({k: v for k, v in n.items() if k in selected_keys})

            clean_response.append(clean_news)

        return {"query": query, "top_k": clean_response}
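Reviewer note: the _clean_response method above keeps only a handful of keys per Bing result before returning them to the caller. The standalone sketch below reproduces that filtering pattern against a fabricated response payload so it runs without the Bing API; it is an illustration, not code from this diff.

# Illustration of the key-filtering pattern used by _clean_response.
search_response = {
    "queryContext": {"originalQuery": "llama stack"},
    "webPages": {
        "value": [
            {"name": "Example", "url": "https://example.com", "snippet": "...", "id": "1"},
        ]
    },
}

selected_keys = {"name", "url", "snippet"}
cleaned = [
    {k: v for k, v in page.items() if k in selected_keys}
    for page in search_response.get("webPages", {}).get("value", [])
]
print({"query": search_response["queryContext"]["originalQuery"], "top_k": cleaned})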
@@ -0,0 +1,16 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import Optional

from pydantic import BaseModel


class BingSearchToolConfig(BaseModel):
    """Configuration for Bing Search Tool Runtime"""

    api_key: Optional[str] = None
    top_k: int = 3
@@ -11,10 +11,10 @@ from .config import BraveSearchToolConfig


 class BraveSearchToolProviderDataValidator(BaseModel):
-    api_key: str
+    brave_search_api_key: str


-async def get_provider_impl(config: BraveSearchToolConfig, _deps):
+async def get_adapter_impl(config: BraveSearchToolConfig, _deps):
     impl = BraveSearchToolRuntimeImpl(config)
     await impl.initialize()
     return impl
@@ -4,11 +4,19 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from typing import Any, Dict, List
+from typing import Any, Dict, List, Optional

 import requests
+from llama_models.llama3.api.datatypes import BuiltinTool

-from llama_stack.apis.tools import Tool, ToolGroupDef, ToolInvocationResult, ToolRuntime
+from llama_stack.apis.common.content_types import URL
+from llama_stack.apis.tools import (
+    Tool,
+    ToolDef,
+    ToolInvocationResult,
+    ToolParameter,
+    ToolRuntime,
+)
 from llama_stack.distribution.request_headers import NeedsRequestProviderData
 from llama_stack.providers.datatypes import ToolsProtocolPrivate

@@ -25,8 +33,7 @@ class BraveSearchToolRuntimeImpl(
         pass

     async def register_tool(self, tool: Tool):
-        if tool.identifier != "brave_search":
-            raise ValueError(f"Tool identifier {tool.identifier} is not supported")
+        pass

     async def unregister_tool(self, tool_id: str) -> None:
         return

@@ -36,14 +43,29 @@ class BraveSearchToolRuntimeImpl(
             return self.config.api_key

         provider_data = self.get_request_provider_data()
-        if provider_data is None or not provider_data.api_key:
+        if provider_data is None or not provider_data.brave_search_api_key:
             raise ValueError(
-                'Pass Search provider\'s API Key in the header X-LlamaStack-ProviderData as { "api_key": <your api key>}'
+                'Pass Search provider\'s API Key in the header X-LlamaStack-Provider-Data as { "brave_search_api_key": <your api key>}'
             )
-        return provider_data.api_key
+        return provider_data.brave_search_api_key

-    async def discover_tools(self, tool_group: ToolGroupDef) -> List[Tool]:
-        raise NotImplementedError("Brave search tool group not supported")
+    async def list_runtime_tools(
+        self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
+    ) -> List[ToolDef]:
+        return [
+            ToolDef(
+                name="web_search",
+                description="Search the web for information",
+                parameters=[
+                    ToolParameter(
+                        name="query",
+                        description="The query to search for",
+                        parameter_type="string",
+                    )
+                ],
+                built_in_type=BuiltinTool.brave_search,
+            )
+        ]

     async def invoke_tool(
         self, tool_name: str, args: Dict[str, Any]
@@ -4,7 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from typing import Optional
+from typing import Any, Dict, Optional

 from pydantic import BaseModel, Field

@@ -18,3 +18,10 @@ class BraveSearchToolConfig(BaseModel):
         default=3,
         description="The maximum number of results to return",
     )
+
+    @classmethod
+    def sample_run_config(cls, __distro_dir__: str) -> Dict[str, Any]:
+        return {
+            "api_key": "${env.BRAVE_SEARCH_API_KEY:}",
+            "max_results": 3,
+        }
@@ -4,22 +4,21 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from typing import Any, Dict, List
+from typing import Any, Dict, List, Optional
 from urllib.parse import urlparse

+from mcp import ClientSession
+from mcp.client.sse import sse_client
+
+from llama_stack.apis.common.content_types import URL
 from llama_stack.apis.tools import (
-    MCPToolGroupDef,
     ToolDef,
-    ToolGroupDef,
     ToolInvocationResult,
     ToolParameter,
     ToolRuntime,
 )
 from llama_stack.providers.datatypes import ToolsProtocolPrivate

-from mcp import ClientSession
-from mcp.client.sse import sse_client
-
 from .config import ModelContextProtocolConfig

@@ -30,12 +29,14 @@ class ModelContextProtocolToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
     async def initialize(self):
         pass

-    async def discover_tools(self, tool_group: ToolGroupDef) -> List[ToolDef]:
-        if not isinstance(tool_group, MCPToolGroupDef):
-            raise ValueError(f"Unsupported tool group type: {type(tool_group)}")
+    async def list_runtime_tools(
+        self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
+    ) -> List[ToolDef]:
+        if mcp_endpoint is None:
+            raise ValueError("mcp_endpoint is required")

         tools = []
-        async with sse_client(tool_group.endpoint.uri) as streams:
+        async with sse_client(mcp_endpoint.uri) as streams:
             async with ClientSession(*streams) as session:
                 await session.initialize()
                 tools_result = await session.list_tools()

@@ -57,7 +58,7 @@ class ModelContextProtocolToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
                         description=tool.description,
                         parameters=parameters,
                         metadata={
-                            "endpoint": tool_group.endpoint.uri,
+                            "endpoint": mcp_endpoint.uri,
                        },
                     )
                 )
Some files were not shown because too many files have changed in this diff.