Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-07-20 19:56:59 +00:00)

Merge branch 'main' into nvidia-e2e-notebook
Commit 51b68b4be6
234 changed files with 21943 additions and 7540 deletions

.github/CODEOWNERS (2 changes, vendored)

@@ -2,4 +2,4 @@
 # These owners will be the default owners for everything in
 # the repo. Unless a later match takes precedence,
-* @ashwinb @yanxi0830 @hardikjshah @dltn @raghotham @dineshyv @vladimirivic @sixianyi0721 @ehhuang @terrytangyuan @SLR722 @leseb
+* @ashwinb @yanxi0830 @hardikjshah @raghotham @ehhuang @terrytangyuan @leseb @bbrowning

.github/TRIAGERS.md (2 changes, vendored)

@@ -1,2 +1,2 @@
 # This file documents Triage members in the Llama Stack community
-@franciscojavierarceo @leseb
+@bbrowning @booxter @franciscojavierarceo @leseb

.github/workflows/integration-auth-tests.yml (35 changes, vendored)

@@ -28,12 +28,13 @@ jobs:
     steps:
       - name: Checkout repository
-        uses: actions/checkout@v4
+        uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2

       - name: Install uv
-        uses: astral-sh/setup-uv@v5
+        uses: astral-sh/setup-uv@6b9c6063abd6010835644d4c2e1bef4cf5cd0fca # v6.0.1
         with:
           python-version: "3.10"
+          activate-environment: true

       - name: Set Up Environment and Install Dependencies
         run: |
@@ -43,7 +44,7 @@ jobs:

       - name: Install minikube
         if: ${{ matrix.auth-provider == 'kubernetes' }}
-        uses: medyagh/setup-minikube@latest
+        uses: medyagh/setup-minikube@cea33675329b799adccc9526aa5daccc26cd5052 # v0.0.19

       - name: Start minikube
         if: ${{ matrix.auth-provider == 'kubernetes' }}
@@ -74,32 +75,8 @@ jobs:
           cat <<'EOF' > $run_dir/run.yaml
           version: '2'
           image_name: kube
-          apis:
-          - agents
-          - datasetio
-          - eval
-          - inference
-          - safety
-          - scoring
-          - telemetry
-          - tool_runtime
-          - vector_io
-          providers:
-            agents:
-            - provider_id: meta-reference
-              provider_type: inline::meta-reference
-              config:
-                persistence_store:
-                  type: sqlite
-                  namespace: null
-                  db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ollama}/agents_store.db
-            telemetry:
-            - provider_id: meta-reference
-              provider_type: inline::meta-reference
-              config:
-                service_name: "${env.OTEL_SERVICE_NAME:\u200B}"
-                sinks: ${env.TELEMETRY_SINKS:console,sqlite}
-                sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/ollama/trace_store.db}
+          apis: []
+          providers: {}
           server:
             port: 8321
           EOF

.github/workflows/integration-tests.yml (20 changes, vendored)

@@ -33,7 +33,7 @@ jobs:
         uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2

       - name: Install uv
-        uses: astral-sh/setup-uv@c7f87aa956e4c323abf06d5dec078e358f6b4d04 # v6.0.0
+        uses: astral-sh/setup-uv@6b9c6063abd6010835644d4c2e1bef4cf5cd0fca # v6.0.1
         with:
           python-version: "3.10"
           activate-environment: true
@@ -58,7 +58,7 @@ jobs:
           INFERENCE_MODEL: "meta-llama/Llama-3.2-3B-Instruct"
         run: |
           source .venv/bin/activate
-          nohup uv run llama stack run ./llama_stack/templates/ollama/run.yaml --image-type venv > server.log 2>&1 &
+          LLAMA_STACK_LOG_FILE=server.log nohup uv run llama stack run ./llama_stack/templates/ollama/run.yaml --image-type venv &

       - name: Wait for Llama Stack server to be ready
         if: matrix.client-type == 'http'
@@ -85,6 +85,11 @@ jobs:
             echo "Ollama health check failed"
             exit 1
           fi

+      - name: Check Storage and Memory Available Before Tests
+        if: ${{ always() }}
+        run: |
+          free -h
+          df -h

       - name: Run Integration Tests
         env:
@@ -100,13 +105,20 @@ jobs:
             --text-model="meta-llama/Llama-3.2-3B-Instruct" \
             --embedding-model=all-MiniLM-L6-v2

+      - name: Check Storage and Memory Available After Tests
+        if: ${{ always() }}
+        run: |
+          free -h
+          df -h

       - name: Write ollama logs to file
+        if: ${{ always() }}
         run: |
           sudo journalctl -u ollama.service > ollama.log

       - name: Upload all logs to artifacts
-        if: always()
-        uses: actions/upload-artifact@v4
+        if: ${{ always() }}
+        uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # v4.6.2
         with:
           name: logs-${{ github.run_id }}-${{ github.run_attempt }}-${{ matrix.client-type }}-${{ matrix.test-type }}
           path: |

.github/workflows/providers-build.yml (18 changes, vendored)

@@ -56,7 +56,7 @@ jobs:
           python-version: '3.10'

       - name: Install uv
-        uses: astral-sh/setup-uv@c7f87aa956e4c323abf06d5dec078e358f6b4d04 # v6.0.0
+        uses: astral-sh/setup-uv@6b9c6063abd6010835644d4c2e1bef4cf5cd0fca # v6.0.1
         with:
           python-version: "3.10"

@@ -94,7 +94,7 @@ jobs:
           python-version: '3.10'

       - name: Install uv
-        uses: astral-sh/setup-uv@c7f87aa956e4c323abf06d5dec078e358f6b4d04 # v6.0.0
+        uses: astral-sh/setup-uv@6b9c6063abd6010835644d4c2e1bef4cf5cd0fca # v6.0.1
         with:
           python-version: "3.10"

@@ -120,7 +120,7 @@ jobs:
           python-version: '3.10'

       - name: Install uv
-        uses: astral-sh/setup-uv@c7f87aa956e4c323abf06d5dec078e358f6b4d04 # v6.0.0
+        uses: astral-sh/setup-uv@6b9c6063abd6010835644d4c2e1bef4cf5cd0fca # v6.0.1
         with:
           python-version: "3.10"

@@ -132,9 +132,9 @@ jobs:

       - name: Build a single provider
         run: |
-          yq -i '.image_type = "container"' llama_stack/templates/dev/build.yaml
-          yq -i '.image_name = "test"' llama_stack/templates/dev/build.yaml
-          USE_COPY_NOT_MOUNT=true LLAMA_STACK_DIR=. uv run llama stack build --config llama_stack/templates/dev/build.yaml
+          yq -i '.image_type = "container"' llama_stack/templates/starter/build.yaml
+          yq -i '.image_name = "test"' llama_stack/templates/starter/build.yaml
+          USE_COPY_NOT_MOUNT=true LLAMA_STACK_DIR=. uv run llama stack build --config llama_stack/templates/starter/build.yaml

       - name: Inspect the container image entrypoint
         run: |
@@ -158,7 +158,7 @@ jobs:
           python-version: '3.10'

       - name: Install uv
-        uses: astral-sh/setup-uv@0c5e2b8115b80b4c7c5ddf6ffdd634974642d182 # v5.4.1
+        uses: astral-sh/setup-uv@6b9c6063abd6010835644d4c2e1bef4cf5cd0fca # v6.0.1
         with:
           python-version: "3.10"

@@ -174,14 +174,14 @@ jobs:
             .image_type = "container" |
             .image_name = "ubi9-test" |
             .distribution_spec.container_image = "registry.access.redhat.com/ubi9:latest"
-          ' llama_stack/templates/dev/build.yaml
+          ' llama_stack/templates/starter/build.yaml

       - name: Build dev container (UBI9)
         env:
           USE_COPY_NOT_MOUNT: "true"
           LLAMA_STACK_DIR: "."
         run: |
-          uv run llama stack build --config llama_stack/templates/dev/build.yaml
+          uv run llama stack build --config llama_stack/templates/starter/build.yaml

       - name: Inspect UBI9 image
         run: |

.github/workflows/test-external-providers.yml (11 changes, vendored)

@@ -23,10 +23,10 @@ jobs:
     # container and point 'uv pip install' to the correct path...
     steps:
       - name: Checkout repository
-        uses: actions/checkout@v4
+        uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2

       - name: Install uv
-        uses: astral-sh/setup-uv@v6
+        uses: astral-sh/setup-uv@6b9c6063abd6010835644d4c2e1bef4cf5cd0fca # v6.0.1
         with:
           python-version: "3.10"

@@ -47,8 +47,8 @@ jobs:

       - name: Create provider configuration
         run: |
-          mkdir -p /tmp/providers.d/remote/inference
-          cp tests/external-provider/llama-stack-provider-ollama/custom_ollama.yaml /tmp/providers.d/remote/inference/custom_ollama.yaml
+          mkdir -p /home/runner/.llama/providers.d/remote/inference
+          cp tests/external-provider/llama-stack-provider-ollama/custom_ollama.yaml /home/runner/.llama/providers.d/remote/inference/custom_ollama.yaml

       - name: Build distro from config file
         run: |
@@ -66,7 +66,7 @@ jobs:
       - name: Wait for Llama Stack server to be ready
         run: |
           for i in {1..30}; do
-            if ! grep -q "remote::custom_ollama from /tmp/providers.d/remote/inference/custom_ollama.yaml" server.log; then
+            if ! grep -q "remote::custom_ollama from /home/runner/.llama/providers.d/remote/inference/custom_ollama.yaml" server.log; then
               echo "Waiting for Llama Stack server to load the provider..."
               sleep 1
             else
@@ -75,4 +75,5 @@ jobs:
             fi
           done
           echo "Provider failed to load"
+          cat server.log
           exit 1

.github/workflows/unit-tests.yml (2 changes, vendored)

@@ -37,7 +37,7 @@ jobs:
         with:
           python-version: ${{ matrix.python }}

-      - uses: astral-sh/setup-uv@c7f87aa956e4c323abf06d5dec078e358f6b4d04 # v6.0.0
+      - uses: astral-sh/setup-uv@6b9c6063abd6010835644d4c2e1bef4cf5cd0fca # v6.0.1
         with:
           python-version: ${{ matrix.python }}
           enable-cache: false

.github/workflows/update-readthedocs.yml (9 changes, vendored)

@@ -14,6 +14,8 @@ on:
       - 'docs/**'
       - 'pyproject.toml'
      - '.github/workflows/update-readthedocs.yml'
+    tags:
+      - '*'
   pull_request:
     branches:
       - main
@@ -41,7 +43,7 @@ jobs:
           python-version: '3.11'

       - name: Install the latest version of uv
-        uses: astral-sh/setup-uv@c7f87aa956e4c323abf06d5dec078e358f6b4d04 # v6.0.0
+        uses: astral-sh/setup-uv@6b9c6063abd6010835644d4c2e1bef4cf5cd0fca # v6.0.1

       - name: Sync with uv
         run: uv sync --extra docs
@@ -61,7 +63,10 @@ jobs:

           response=$(curl -X POST \
             -H "Content-Type: application/json" \
-            -d "{\"token\": \"$TOKEN\"}" \
+            -d "{
+              \"token\": \"$TOKEN\",
+              \"version\": \"$GITHUB_REF_NAME\"
+            }" \
             https://readthedocs.org/api/v2/webhook/llama-stack/289768/)

           echo "Response: $response"

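The webhook change above now sends the git ref alongside the token, so tag pushes can trigger versioned builds. For illustration, the same request as a minimal Python sketch (the `requests` dependency and the environment-variable names are assumptions based on the workflow, not part of the commit):

```python
import os

import requests

# Trigger a Read the Docs build, passing the ref so tag builds work too.
payload = {
    "token": os.environ["TOKEN"],              # RTD integration token (repo secret)
    "version": os.environ["GITHUB_REF_NAME"],  # e.g. "main" or a tag like "v0.2.5"
}
resp = requests.post(
    "https://readthedocs.org/api/v2/webhook/llama-stack/289768/",
    json=payload,
)
print("Response:", resp.text)
```
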
@@ -106,6 +106,14 @@ repos:
       pass_filenames: false
       require_serial: true
       files: ^llama_stack/apis/|^docs/openapi_generator/
+  - id: check-workflows-use-hashes
+    name: Check GitHub Actions use SHA-pinned actions
+    entry: ./scripts/check-workflows-use-hashes.sh
+    language: system
+    pass_filenames: false
+    require_serial: true
+    always_run: true
+    files: ^\.github/workflows/.*\.ya?ml$

 ci:
   autofix_commit_msg: 🎨 [pre-commit.ci] Auto format from pre-commit.com hooks

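The new hook delegates to `./scripts/check-workflows-use-hashes.sh`, which enforces the SHA-pinning applied across the workflow files above. As a rough illustration of the check it performs (this sketch is an assumption, not the actual script; the regex and output format are invented for clarity):

```python
import re
import sys
from pathlib import Path

# A `uses:` reference counts as pinned only when the ref is a full
# 40-character commit SHA, e.g. actions/checkout@11bd719...683.
USES_RE = re.compile(r"uses:\s*([\w./-]+)@([A-Za-z0-9._-]+)")
SHA_RE = re.compile(r"[0-9a-f]{40}")

failed = False
for workflow in Path(".github/workflows").glob("*.y*ml"):
    for action, ref in USES_RE.findall(workflow.read_text()):
        if not SHA_RE.fullmatch(ref):
            print(f"{workflow}: {action}@{ref} is not SHA-pinned")
            failed = True

sys.exit(1 if failed else 0)
```
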
CHANGELOG.md (21 changes)

@@ -1,5 +1,26 @@
 # Changelog

+# v0.2.5
+Published on: 2025-05-04T20:16:49Z
+
+---
+
+# v0.2.4
+Published on: 2025-04-29T17:26:01Z
+
+## Highlights
+
+* One-liner to install and run Llama Stack yay! by @reluctantfuturist in https://github.com/meta-llama/llama-stack/pull/1383
+* support for NVIDIA NeMo datastore by @raspawar in https://github.com/meta-llama/llama-stack/pull/1852
+* (yuge!) Kubernetes authentication by @leseb in https://github.com/meta-llama/llama-stack/pull/1778
+* (yuge!) OpenAI Responses API by @bbrowning in https://github.com/meta-llama/llama-stack/pull/1989
+* add api.llama provider, llama-guard-4 model by @ashwinb in https://github.com/meta-llama/llama-stack/pull/2058
+
+---
+
 # v0.2.3
 Published on: 2025-04-25T22:46:21Z

@@ -110,25 +110,9 @@ uv run pre-commit run --all-files
 > [!CAUTION]
 > Before pushing your changes, make sure that the pre-commit hooks have passed successfully.

-## Running unit tests
+## Running tests

-You can run the unit tests by running:
-
-```bash
-source .venv/bin/activate
-./scripts/unit-tests.sh
-```
-
-If you'd like to run for a non-default version of Python (currently 3.10), pass `PYTHON_VERSION` variable as follows:
-
-```
-source .venv/bin/activate
-PYTHON_VERSION=3.13 ./scripts/unit-tests.sh
-```
-
-## Running integration tests
-
-You can run integration tests following the instructions [here](tests/integration/README.md).
+You can find the Llama Stack testing documentation [here](tests/README.md).

 ## Adding a new dependency to the project
@@ -153,6 +137,8 @@ uv sync
   justification for bypassing the check.
 * When using `# type: ignore` to suppress a mypy warning, include a comment explaining the
   justification for bypassing the check.
+* Don't use unicode characters in the codebase. ASCII-only is preferred for compatibility or
+  readability reasons.

 ## Common Tasks

@@ -7,7 +7,7 @@
 [](https://github.com/meta-llama/llama-stack/actions/workflows/unit-tests.yml?query=branch%3Amain)
 [](https://github.com/meta-llama/llama-stack/actions/workflows/integration-tests.yml?query=branch%3Amain)

-[**Quick Start**](https://llama-stack.readthedocs.io/en/latest/getting_started/index.html) | [**Documentation**](https://llama-stack.readthedocs.io/en/latest/index.html) | [**Colab Notebook**](./docs/getting_started.ipynb)
+[**Quick Start**](https://llama-stack.readthedocs.io/en/latest/getting_started/index.html) | [**Documentation**](https://llama-stack.readthedocs.io/en/latest/index.html) | [**Colab Notebook**](./docs/getting_started.ipynb) | [**Discord**](https://discord.gg/llama-stack)

 ### ✨🎉 Llama 4 Support 🎉✨
 We released [Version 0.2.0](https://github.com/meta-llama/llama-stack/releases/tag/v0.2.0) with support for the Llama 4 herd of models released by Meta.

docs/_static/llama-stack-spec.html (3952 changes, vendored): file diff suppressed because it is too large.

docs/_static/llama-stack-spec.yaml (2928 changes, vendored): file diff suppressed because it is too large.

@@ -1050,8 +1050,6 @@
     "text/html": [
      "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\">ToolGroup</span><span style=\"font-weight: bold\">(</span>\n",
      "<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ </span><span style=\"color: #808000; text-decoration-color: #808000\">identifier</span>=<span style=\"color: #008000; text-decoration-color: #008000\">'builtin::code_interpreter'</span>,\n",
-     "<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ </span><span style=\"color: #808000; text-decoration-color: #808000\">provider_id</span>=<span style=\"color: #008000; text-decoration-color: #008000\">'code-interpreter'</span>,\n",
      "<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ </span><span style=\"color: #808000; text-decoration-color: #808000\">provider_resource_id</span>=<span style=\"color: #008000; text-decoration-color: #008000\">'builtin::code_interpreter'</span>,\n",
      "<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ </span><span style=\"color: #808000; text-decoration-color: #808000\">type</span>=<span style=\"color: #008000; text-decoration-color: #008000\">'tool_group'</span>,\n",
      "<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ </span><span style=\"color: #808000; text-decoration-color: #808000\">args</span>=<span style=\"color: #800080; text-decoration-color: #800080; font-style: italic\">None</span>,\n",
      "<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ </span><span style=\"color: #808000; text-decoration-color: #808000\">mcp_endpoint</span>=<span style=\"color: #800080; text-decoration-color: #800080; font-style: italic\">None</span>\n",
@@ -1061,7 +1059,6 @@
     "text/plain": [
      "\u001b[1;35mToolGroup\u001b[0m\u001b[1m(\u001b[0m\n",
      "\u001b[2;32m│ \u001b[0m\u001b[33midentifier\u001b[0m=\u001b[32m'builtin::code_interpreter'\u001b[0m,\n",
-     "\u001b[2;32m│ \u001b[0m\u001b[33mprovider_id\u001b[0m=\u001b[32m'code-interpreter'\u001b[0m,\n",
      "\u001b[2;32m│ \u001b[0m\u001b[33mprovider_resource_id\u001b[0m=\u001b[32m'builtin::code_interpreter'\u001b[0m,\n",
      "\u001b[2;32m│ \u001b[0m\u001b[33mtype\u001b[0m=\u001b[32m'tool_group'\u001b[0m,\n",
      "\u001b[2;32m│ \u001b[0m\u001b[33margs\u001b[0m=\u001b[3;35mNone\u001b[0m,\n",

@@ -337,9 +337,6 @@
      " provider_id: tavily-search\n",
      " provider_type: remote::tavily-search\n",
      " - config: <span style=\"font-weight: bold\">{}</span>\n",
-     " provider_id: code-interpreter\n",
-     " provider_type: inlin<span style=\"color: #00ff00; text-decoration-color: #00ff00; font-weight: bold\">e::c</span>ode-interpreter\n",
-     " - config: <span style=\"font-weight: bold\">{}</span>\n",
      " provider_id: rag-runtime\n",
      " provider_type: inline::rag-runtime\n",
      " - config: <span style=\"font-weight: bold\">{}</span>\n",
@@ -378,10 +375,6 @@
      " toolgroup_id: builtin::rag\n",
      "- args: null\n",
      " mcp_endpoint: null\n",
-     " provider_id: code-interpreter\n",
-     " toolgroup_id: builtin::code_interpreter\n",
-     "- args: null\n",
-     " mcp_endpoint: null\n",
      " provider_id: wolfram-alpha\n",
      " toolgroup_id: builtin::wolfram_alpha\n",
      "vector_dbs: <span style=\"font-weight: bold\">[]</span>\n",
@@ -617,9 +610,6 @@
      " provider_id: tavily-search\n",
      " provider_type: remote::tavily-search\n",
      " - config: \u001b[1m{\u001b[0m\u001b[1m}\u001b[0m\n",
-     " provider_id: code-interpreter\n",
-     " provider_type: inlin\u001b[1;92me::c\u001b[0mode-interpreter\n",
-     " - config: \u001b[1m{\u001b[0m\u001b[1m}\u001b[0m\n",
      " provider_id: rag-runtime\n",
      " provider_type: inline::rag-runtime\n",
      " - config: \u001b[1m{\u001b[0m\u001b[1m}\u001b[0m\n",
@@ -658,10 +648,6 @@
      " toolgroup_id: builtin::rag\n",
      "- args: null\n",
      " mcp_endpoint: null\n",
-     " provider_id: code-interpreter\n",
-     " toolgroup_id: builtin::code_interpreter\n",
-     "- args: null\n",
-     " mcp_endpoint: null\n",
      " provider_id: wolfram-alpha\n",
      " toolgroup_id: builtin::wolfram_alpha\n",
      "vector_dbs: \u001b[1m[\u001b[0m\u001b[1m]\u001b[0m\n",

@@ -840,7 +840,6 @@
      " \"memory_optimizations.rst\",\n",
      " \"chat.rst\",\n",
      " \"llama3.rst\",\n",
-     " \"datasets.rst\",\n",
      " \"qat_finetune.rst\",\n",
      " \"lora_finetune.rst\",\n",
      "]\n",
@@ -1586,7 +1585,6 @@
      " \"memory_optimizations.rst\",\n",
      " \"chat.rst\",\n",
      " \"llama3.rst\",\n",
-     " \"datasets.rst\",\n",
      " \"qat_finetune.rst\",\n",
      " \"lora_finetune.rst\",\n",
      "]\n",

@@ -44,7 +44,7 @@ def main(output_dir: str):
     if return_type_errors:
         print("\nAPI Method Return Type Validation Errors:\n")
         for error in return_type_errors:
-            print(error)
+            print(error, file=sys.stderr)
         sys.exit(1)
     now = str(datetime.now())
     print(

@@ -759,7 +759,7 @@ class Generator:
         )

         return Operation(
-            tags=[op.defining_class.__name__],
+            tags=[getattr(op.defining_class, "API_NAMESPACE", op.defining_class.__name__)],
             summary=None,
             # summary=doc_string.short_description,
             description=description,
@@ -805,6 +805,8 @@ class Generator:
         operation_tags: List[Tag] = []
         for cls in endpoint_classes:
             doc_string = parse_type(cls)
+            if hasattr(cls, "API_NAMESPACE") and cls.API_NAMESPACE != cls.__name__:
+                continue
             operation_tags.append(
                 Tag(
                     name=cls.__name__,

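Together, the two hunks above let one endpoint class publish its operations under another class's OpenAPI tag while suppressing its own tag entry. A hypothetical sketch of the mechanism (the class names below are illustrative, not from the codebase):

```python
class Inference:
    """Canonical API class; its name is the OpenAPI tag."""


class BatchInference:
    # Operations defined here are tagged "Inference", and no separate
    # "BatchInference" tag is emitted (the generator skips this class).
    API_NAMESPACE = "Inference"


def tag_for(cls) -> str:
    # Mirrors tags=[getattr(op.defining_class, "API_NAMESPACE", ...)]
    return getattr(cls, "API_NAMESPACE", cls.__name__)


assert tag_for(Inference) == "Inference"
assert tag_for(BatchInference) == "Inference"
```
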
@@ -174,14 +174,64 @@ def _validate_list_parameters_contain_data(method) -> str | None:
     return "does not have a mandatory data attribute containing the list of objects"


+def _validate_has_ellipsis(method) -> str | None:
+    source = inspect.getsource(method)
+    if "..." not in source and not "NotImplementedError" in source:
+        return "does not contain ellipsis (...) in its implementation"
+
+
+def _validate_has_return_in_docstring(method) -> str | None:
+    source = inspect.getsource(method)
+    return_type = method.__annotations__.get('return')
+    if return_type is not None and return_type != type(None) and ":returns:" not in source:
+        return "does not have a ':returns:' in its docstring"
+
+
+def _validate_has_params_in_docstring(method) -> str | None:
+    source = inspect.getsource(method)
+    sig = inspect.signature(method)
+    # Only check if the method has more than one parameter
+    if len(sig.parameters) > 1 and ":param" not in source:
+        return "does not have a ':param' in its docstring"
+
+
+def _validate_has_no_return_none_in_docstring(method) -> str | None:
+    source = inspect.getsource(method)
+    return_type = method.__annotations__.get('return')
+    if return_type is None and ":returns: None" in source:
+        return "has a ':returns: None' in its docstring which is redundant for None-returning functions"
+
+
+def _validate_docstring_lines_end_with_dot(method) -> str | None:
+    docstring = inspect.getdoc(method)
+    if docstring is None:
+        return None
+
+    lines = docstring.split('\n')
+    for line in lines:
+        line = line.strip()
+        if line and not any(line.endswith(char) for char in '.:{}[]()",'):
+            return f"docstring line '{line}' does not end with a valid character: . : {{ }} [ ] ( ) , \""
+
+
 _VALIDATORS = {
     "GET": [
         _validate_api_method_return_type,
         _validate_list_parameters_contain_data,
         _validate_api_method_doesnt_return_list,
+        _validate_has_ellipsis,
+        _validate_has_return_in_docstring,
+        _validate_has_params_in_docstring,
+        _validate_docstring_lines_end_with_dot,
     ],
     "DELETE": [
         _validate_api_delete_method_returns_none,
+        _validate_has_ellipsis,
+        _validate_has_return_in_docstring,
+        _validate_has_params_in_docstring,
+        _validate_has_no_return_none_in_docstring
+    ],
+    "POST": [
+        _validate_has_ellipsis,
+        _validate_has_return_in_docstring,
+        _validate_has_params_in_docstring,
+        _validate_has_no_return_none_in_docstring,
+        _validate_docstring_lines_end_with_dot,
+    ],
 }

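For context, each validator above returns `None` on success or a failure description string. A minimal sketch of how such a table might be driven (the loop below is an illustration under that assumption, not code from the diff):

```python
def run_validators(method, http_verb: str) -> list[str]:
    """Collect human-readable failures for one API method."""
    failures = []
    for validator in _VALIDATORS.get(http_verb, []):
        message = validator(method)  # None means the check passed
        if message:
            failures.append(f"{method.__name__} {message}")
    return failures
```
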
@@ -51,6 +51,7 @@ chunks = [
         "mime_type": "text/plain",
         "metadata": {
             "document_id": "doc1",
+            "author": "Jane Doe",
         },
     },
 ]
@@ -98,6 +99,17 @@ results = client.tool_runtime.rag_tool.query(
 )
 ```

+You can configure how the RAG tool adds metadata to the context if you find it useful for your application. Simply add:
+```python
+# Query documents
+results = client.tool_runtime.rag_tool.query(
+    vector_db_ids=[vector_db_id],
+    content="What do you know about...",
+    query_config={
+        "chunk_template": "Result {index}\nContent: {chunk.content}\nMetadata: {metadata}\n",
+    },
+)
+```
 ### Building RAG-Enhanced Agents

 One of the most powerful patterns is combining agents with RAG capabilities. Here's a complete example:
@@ -115,6 +127,12 @@ agent = Agent(
             "name": "builtin::rag/knowledge_search",
             "args": {
                 "vector_db_ids": [vector_db_id],
+                # Defaults
+                "query_config": {
+                    "chunk_size_in_tokens": 512,
+                    "chunk_overlap_in_tokens": 0,
+                    "chunk_template": "Result {index}\nContent: {chunk.content}\nMetadata: {metadata}\n",
+                },
             },
         }
     ],

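The `chunk_template` added above is a format string expanded once per retrieved chunk. A rough sketch of that substitution (the stand-in `Chunk` class and rendering loop are assumptions for illustration; only the template and its field names come from the diff):

```python
class Chunk:
    """Stand-in for a retrieved chunk with content and metadata."""

    def __init__(self, content: str, metadata: dict):
        self.content = content
        self.metadata = metadata


chunk_template = "Result {index}\nContent: {chunk.content}\nMetadata: {metadata}\n"
chunks = [Chunk("Llama Stack supports...", {"document_id": "doc1", "author": "Jane Doe"})]

# Each retrieved chunk is rendered through the template and concatenated
# into the context string handed to the model.
context = "".join(
    chunk_template.format(index=i + 1, chunk=chunk, metadata=chunk.metadata)
    for i, chunk in enumerate(chunks)
)
print(context)
```
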
@@ -165,34 +165,6 @@ all_tools = client.tools.list_tools()
 group_tools = client.tools.list_tools(toolgroup_id="search_tools")
 ```

-## Simple Example: Using an Agent with the Code-Interpreter Tool
-
-```python
-from llama_stack_client import Agent
-
-# Instantiate the AI agent with the given configuration
-agent = Agent(
-    client,
-    name="code-interpreter",
-    description="A code interpreter agent for executing Python code snippets",
-    instructions="""
-    You are a highly reliable, concise, and precise assistant.
-    Always show the generated code, never generate your own code, and never anticipate results.
-    """,
-    model="meta-llama/Llama-3.2-3B-Instruct",
-    tools=["builtin::code_interpreter"],
-    max_infer_iters=5,
-)
-
-# Start a session
-session_id = agent.create_session("tool_session")
-
-# Send a query to the AI agent for code execution
-response = agent.create_turn(
-    messages=[{"role": "user", "content": "Run this code: print(3 ** 4 - 5 * 2)"}],
-    session_id=session_id,
-)
-```
 ## Simple Example 2: Using an Agent with the Web Search Tool
 1. Start by registering a Tavily API key at [Tavily](https://tavily.com/).
 2. [Optional] Provide the API key directly to the Llama Stack server

@@ -110,6 +110,8 @@ html_theme_options = {
     "canonical_url": "https://github.com/meta-llama/llama-stack",
     "collapse_navigation": False,
     # "style_nav_header_background": "#c3c9d4",
+    'display_version': True,
+    'version_selector': True,
 }

 default_dark_mode = False

@@ -6,7 +6,7 @@ This guide will walk you through the process of adding a new API provider to Lla
 - Begin by reviewing the [core concepts](../concepts/index.md) of Llama Stack and choose the API your provider belongs to (Inference, Safety, VectorIO, etc.)
 - Determine the provider type ({repopath}`Remote::llama_stack/providers/remote` or {repopath}`Inline::llama_stack/providers/inline`). Remote providers make requests to external services, while inline providers execute implementation locally.
 - Add your provider to the appropriate {repopath}`Registry::llama_stack/providers/registry/`. Specify pip dependencies necessary.
-- Update any distribution {repopath}`Templates::llama_stack/templates/` build.yaml and run.yaml files if they should include your provider by default. Run {repopath}`./scripts/distro_codegen.py` if necessary. Note that `distro_codegen.py` will fail if the new provider causes any distribution template to attempt to import provider-specific dependencies. This usually means the distribution's `get_distribution_template()` code path should only import any necessary Config or model alias definitions from each provider and not the provider's actual implementation.
+- Update any distribution {repopath}`Templates::llama_stack/templates/` `build.yaml` and `run.yaml` files if they should include your provider by default. Run {repopath}`./scripts/distro_codegen.py` if necessary. Note that `distro_codegen.py` will fail if the new provider causes any distribution template to attempt to import provider-specific dependencies. This usually means the distribution's `get_distribution_template()` code path should only import any necessary Config or model alias definitions from each provider and not the provider's actual implementation.


 Here are some example PRs to help you get started:
@@ -33,6 +33,7 @@ Note that each provider's `sample_run_config()` method (in the configuration cla

 Unit tests are located in {repopath}`tests/unit`. Provider-specific unit tests are located in {repopath}`tests/unit/providers`. These tests are all run automatically as part of the CI process.

+Consult {repopath}`tests/unit/README.md` for more details on how to run the tests manually.

 ### 3. Additional end-to-end testing

@@ -178,7 +178,7 @@ image_name: ollama
 image_type: conda

 # If some providers are external, you can specify the path to the implementation
-external_providers_dir: /etc/llama-stack/providers.d
+external_providers_dir: ~/.llama/providers.d
 ```

 ```
@@ -206,7 +206,7 @@ distribution_spec:
   image_type: container
 image_name: ci-test
 # Path to external provider implementations
-external_providers_dir: /etc/llama-stack/providers.d
+external_providers_dir: ~/.llama/providers.d
 ```

 Here's an example for a custom Ollama provider:

@@ -271,7 +271,7 @@ Now, let's start the Llama Stack Distribution Server. You will need the YAML con

 ```
 llama stack run -h
-usage: llama stack run [-h] [--port PORT] [--image-name IMAGE_NAME] [--disable-ipv6] [--env KEY=VALUE] [--tls-keyfile TLS_KEYFILE] [--tls-certfile TLS_CERTFILE]
+usage: llama stack run [-h] [--port PORT] [--image-name IMAGE_NAME] [--env KEY=VALUE] [--tls-keyfile TLS_KEYFILE] [--tls-certfile TLS_CERTFILE]
                        [--image-type {conda,container,venv}]
                        config

@@ -285,7 +285,6 @@ options:
   --port PORT           Port to run the server on. It can also be passed via the env var LLAMA_STACK_PORT. (default: 8321)
   --image-name IMAGE_NAME
                         Name of the image to run. Defaults to the current environment (default: None)
-  --disable-ipv6        Disable IPv6 support (default: False)
   --env KEY=VALUE       Environment variables to pass to the server in KEY=VALUE format. Can be specified multiple times. (default: [])
   --tls-keyfile TLS_KEYFILE
                         Path to TLS key file for HTTPS (default: None)

@@ -172,7 +172,7 @@ spec:
       - name: llama-stack
         image: localhost/llama-stack-run-k8s:latest
         imagePullPolicy: IfNotPresent
-        command: ["python", "-m", "llama_stack.distribution.server.server", "--yaml-config", "/app/config.yaml"]
+        command: ["python", "-m", "llama_stack.distribution.server.server", "--config", "/app/config.yaml"]
         ports:
           - containerPort: 5000
         volumeMounts:

@@ -18,7 +18,7 @@ The `llamastack/distribution-watsonx` distribution consists of the following pro
 | agents | `inline::meta-reference` |
 | datasetio | `remote::huggingface`, `inline::localfs` |
 | eval | `inline::meta-reference` |
-| inference | `remote::watsonx` |
+| inference | `remote::watsonx`, `inline::sentence-transformers` |
 | safety | `inline::llama-guard` |
 | scoring | `inline::basic`, `inline::llm-as-judge`, `inline::braintrust` |
 | telemetry | `inline::meta-reference` |
@@ -70,7 +70,7 @@ docker run \
   -p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
   -v ./run.yaml:/root/my-run.yaml \
   llamastack/distribution-watsonx \
-  --yaml-config /root/my-run.yaml \
+  --config /root/my-run.yaml \
   --port $LLAMA_STACK_PORT \
   --env WATSONX_API_KEY=$WATSONX_API_KEY \
   --env WATSONX_PROJECT_ID=$WATSONX_PROJECT_ID \

@@ -52,7 +52,7 @@ docker run \
   -p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
   -v ./run.yaml:/root/my-run.yaml \
   llamastack/distribution-cerebras \
-  --yaml-config /root/my-run.yaml \
+  --config /root/my-run.yaml \
   --port $LLAMA_STACK_PORT \
   --env CEREBRAS_API_KEY=$CEREBRAS_API_KEY
 ```

@@ -23,7 +23,7 @@ The `llamastack/distribution-dell` distribution consists of the following provid
 | safety | `inline::llama-guard` |
 | scoring | `inline::basic`, `inline::llm-as-judge`, `inline::braintrust` |
 | telemetry | `inline::meta-reference` |
-| tool_runtime | `remote::brave-search`, `remote::tavily-search`, `inline::code-interpreter`, `inline::rag-runtime` |
+| tool_runtime | `remote::brave-search`, `remote::tavily-search`, `inline::rag-runtime` |
 | vector_io | `inline::faiss`, `remote::chromadb`, `remote::pgvector` |

@@ -155,7 +155,7 @@ docker run \
   -v $HOME/.llama:/root/.llama \
   -v ./llama_stack/templates/tgi/run-with-safety.yaml:/root/my-run.yaml \
   llamastack/distribution-dell \
-  --yaml-config /root/my-run.yaml \
+  --config /root/my-run.yaml \
   --port $LLAMA_STACK_PORT \
   --env INFERENCE_MODEL=$INFERENCE_MODEL \
   --env DEH_URL=$DEH_URL \

@@ -144,7 +144,7 @@ docker run \
   -p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
   -v ./run.yaml:/root/my-run.yaml \
   llamastack/distribution-nvidia \
-  --yaml-config /root/my-run.yaml \
+  --config /root/my-run.yaml \
   --port $LLAMA_STACK_PORT \
   --env NVIDIA_API_KEY=$NVIDIA_API_KEY
 ```

@@ -19,6 +19,7 @@ The `llamastack/distribution-ollama` distribution consists of the following prov
 | datasetio | `remote::huggingface`, `inline::localfs` |
 | eval | `inline::meta-reference` |
 | inference | `remote::ollama` |
+| post_training | `inline::huggingface` |
 | safety | `inline::llama-guard` |
 | scoring | `inline::basic`, `inline::llm-as-judge`, `inline::braintrust` |
 | telemetry | `inline::meta-reference` |
@@ -97,7 +98,7 @@ docker run \
   -v ~/.llama:/root/.llama \
   -v ./llama_stack/templates/ollama/run-with-safety.yaml:/root/my-run.yaml \
   llamastack/distribution-ollama \
-  --yaml-config /root/my-run.yaml \
+  --config /root/my-run.yaml \
   --port $LLAMA_STACK_PORT \
   --env INFERENCE_MODEL=$INFERENCE_MODEL \
   --env SAFETY_MODEL=$SAFETY_MODEL \

@@ -233,7 +233,7 @@ docker run \
   -p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
   -v ./llama_stack/templates/remote-vllm/run.yaml:/root/my-run.yaml \
   llamastack/distribution-remote-vllm \
-  --yaml-config /root/my-run.yaml \
+  --config /root/my-run.yaml \
   --port $LLAMA_STACK_PORT \
   --env INFERENCE_MODEL=$INFERENCE_MODEL \
   --env VLLM_URL=http://host.docker.internal:$INFERENCE_PORT/v1
@@ -255,7 +255,7 @@ docker run \
   -v ~/.llama:/root/.llama \
   -v ./llama_stack/templates/remote-vllm/run-with-safety.yaml:/root/my-run.yaml \
   llamastack/distribution-remote-vllm \
-  --yaml-config /root/my-run.yaml \
+  --config /root/my-run.yaml \
   --port $LLAMA_STACK_PORT \
   --env INFERENCE_MODEL=$INFERENCE_MODEL \
   --env VLLM_URL=http://host.docker.internal:$INFERENCE_PORT/v1 \

@@ -16,10 +16,10 @@ The `llamastack/distribution-sambanova` distribution consists of the following p
 | API | Provider(s) |
 |-----|-------------|
 | agents | `inline::meta-reference` |
-| inference | `remote::sambanova` |
+| inference | `remote::sambanova`, `inline::sentence-transformers` |
 | safety | `inline::llama-guard` |
 | telemetry | `inline::meta-reference` |
-| tool_runtime | `remote::brave-search`, `remote::tavily-search`, `inline::rag-runtime` |
+| tool_runtime | `remote::brave-search`, `remote::tavily-search`, `inline::rag-runtime`, `remote::model-context-protocol`, `remote::wolfram-alpha` |
 | vector_io | `inline::faiss`, `remote::chromadb`, `remote::pgvector` |

@@ -28,22 +28,22 @@ The `llamastack/distribution-sambanova` distribution consists of the following p
 The following environment variables can be configured:

 - `LLAMASTACK_PORT`: Port for the Llama Stack distribution server (default: `8321`)
-- `SAMBANOVA_API_KEY`: SambaNova.AI API Key (default: ``)
+- `SAMBANOVA_API_KEY`: SambaNova API Key (default: ``)

 ### Models

 The following models are available by default:

-- `Meta-Llama-3.1-8B-Instruct (aliases: meta-llama/Llama-3.1-8B-Instruct)`
-- `Meta-Llama-3.1-70B-Instruct (aliases: meta-llama/Llama-3.1-70B-Instruct)`
-- `Meta-Llama-3.1-405B-Instruct (aliases: meta-llama/Llama-3.1-405B-Instruct-FP8)`
-- `Meta-Llama-3.2-1B-Instruct (aliases: meta-llama/Llama-3.2-1B-Instruct)`
-- `Meta-Llama-3.2-3B-Instruct (aliases: meta-llama/Llama-3.2-3B-Instruct)`
-- `Meta-Llama-3.3-70B-Instruct (aliases: meta-llama/Llama-3.3-70B-Instruct)`
-- `Llama-3.2-11B-Vision-Instruct (aliases: meta-llama/Llama-3.2-11B-Vision-Instruct)`
-- `Llama-3.2-90B-Vision-Instruct (aliases: meta-llama/Llama-3.2-90B-Vision-Instruct)`
-- `Meta-Llama-Guard-3-8B (aliases: meta-llama/Llama-Guard-3-8B)`
-- `Llama-4-Scout-17B-16E-Instruct (aliases: meta-llama/Llama-4-Scout-17B-16E-Instruct)`
+- `sambanova/Meta-Llama-3.1-8B-Instruct (aliases: meta-llama/Llama-3.1-8B-Instruct)`
+- `sambanova/Meta-Llama-3.1-405B-Instruct (aliases: meta-llama/Llama-3.1-405B-Instruct-FP8)`
+- `sambanova/Meta-Llama-3.2-1B-Instruct (aliases: meta-llama/Llama-3.2-1B-Instruct)`
+- `sambanova/Meta-Llama-3.2-3B-Instruct (aliases: meta-llama/Llama-3.2-3B-Instruct)`
+- `sambanova/Meta-Llama-3.3-70B-Instruct (aliases: meta-llama/Llama-3.3-70B-Instruct)`
+- `sambanova/Llama-3.2-11B-Vision-Instruct (aliases: meta-llama/Llama-3.2-11B-Vision-Instruct)`
+- `sambanova/Llama-3.2-90B-Vision-Instruct (aliases: meta-llama/Llama-3.2-90B-Vision-Instruct)`
+- `sambanova/Llama-4-Scout-17B-16E-Instruct (aliases: meta-llama/Llama-4-Scout-17B-16E-Instruct)`
+- `sambanova/Llama-4-Maverick-17B-128E-Instruct (aliases: meta-llama/Llama-4-Maverick-17B-128E-Instruct)`
+- `sambanova/Meta-Llama-Guard-3-8B (aliases: meta-llama/Llama-Guard-3-8B)`

 ### Prerequisite: API Keys

@@ -117,7 +117,7 @@ docker run \
   -v ~/.llama:/root/.llama \
   -v ./llama_stack/templates/tgi/run-with-safety.yaml:/root/my-run.yaml \
   llamastack/distribution-tgi \
-  --yaml-config /root/my-run.yaml \
+  --config /root/my-run.yaml \
   --port $LLAMA_STACK_PORT \
   --env INFERENCE_MODEL=$INFERENCE_MODEL \
   --env TGI_URL=http://host.docker.internal:$INFERENCE_PORT \

@@ -42,7 +42,7 @@ powershell -ExecutionPolicy ByPass -c "irm https://astral.sh/uv/install.ps1 | ie
 Setup your virtual environment.

 ```bash
-uv venv --python 3.10
+uv sync --python 3.10
 source .venv/bin/activate
 ```
 ## Step 2: Run Llama Stack

@@ -445,7 +445,6 @@ from llama_stack_client import LlamaStackClient
 from llama_stack_client import Agent, AgentEventLogger
 from llama_stack_client.types import Document
 import uuid
-from termcolor import cprint

 client = LlamaStackClient(base_url="http://localhost:8321")

@@ -463,7 +462,6 @@ urls = [
     "memory_optimizations.rst",
     "chat.rst",
     "llama3.rst",
-    "datasets.rst",
     "qat_finetune.rst",
     "lora_finetune.rst",
 ]

@ -10,7 +10,7 @@ Llama Stack supports external providers that live outside of the main codebase.
|
||||||
To enable external providers, you need to configure the `external_providers_dir` in your Llama Stack configuration. This directory should contain your external provider specifications:
|
To enable external providers, you need to configure the `external_providers_dir` in your Llama Stack configuration. This directory should contain your external provider specifications:
|
||||||
|
|
||||||
```yaml
|
```yaml
|
||||||
external_providers_dir: /etc/llama-stack/providers.d/
|
external_providers_dir: ~/.llama/providers.d/
|
||||||
```
|
```
|
||||||
|
|
||||||
## Directory Structure
|
## Directory Structure
|
||||||
|
@@ -53,7 +53,7 @@ Here's a list of known external providers that you can use with Llama Stack:
 | Name | Description | API | Type | Repository |
 |------|-------------|-----|------|------------|
 | KubeFlow Training | Train models with KubeFlow | Post Training | Remote | [llama-stack-provider-kft](https://github.com/opendatahub-io/llama-stack-provider-kft) |
-| KubeFlow Pipelines | Train models with KubeFlow Pipelines | Post Training | Remote | [llama-stack-provider-kfp-trainer](https://github.com/opendatahub-io/llama-stack-provider-kfp-trainer) |
+| KubeFlow Pipelines | Train models with KubeFlow Pipelines | Post Training | Inline **and** Remote | [llama-stack-provider-kfp-trainer](https://github.com/opendatahub-io/llama-stack-provider-kfp-trainer) |
 | RamaLama | Inference models with RamaLama | Inference | Remote | [ramalama-stack](https://github.com/containers/ramalama-stack) |
 | TrustyAI LM-Eval | Evaluate models with TrustyAI LM-Eval | Eval | Remote | [llama-stack-provider-lmeval](https://github.com/trustyai-explainability/llama-stack-provider-lmeval) |
@@ -182,7 +182,7 @@ dependencies = ["llama-stack", "pydantic", "ollama", "aiohttp"]
 3. Create the provider specification:

 ```yaml
-# /etc/llama-stack/providers.d/remote/inference/custom_ollama.yaml
+# ~/.llama/providers.d/remote/inference/custom_ollama.yaml
 adapter:
   adapter_type: custom_ollama
   pip_packages: ["ollama", "aiohttp"]
@@ -201,7 +201,7 @@ uv pip install -e .
 5. Configure Llama Stack to use external providers:

 ```yaml
-external_providers_dir: /etc/llama-stack/providers.d/
+external_providers_dir: ~/.llama/providers.d/
 ```

 The provider will now be available in Llama Stack with the type `remote::custom_ollama`.
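For orientation, the spec path used in step 3 implies a layout of `providers.d/<remote|inline>/<api>/<provider>.yaml`; a minimal tree for the custom Ollama adapter would presumably look like:

```
~/.llama/providers.d/
└── remote/
    └── inference/
        └── custom_ollama.yaml
```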
@@ -253,8 +253,6 @@ llama-stack-client toolgroups list
 +---------------------------+------------------+------+---------------+
 | identifier                | provider_id      | args | mcp_endpoint  |
 +===========================+==================+======+===============+
-| builtin::code_interpreter | code-interpreter | None | None          |
-+---------------------------+------------------+------+---------------+
 | builtin::rag              | rag-runtime      | None | None          |
 +---------------------------+------------------+------+---------------+
 | builtin::websearch        | tavily-search    | None | None          |
61
install.sh
@@ -38,6 +38,67 @@ wait_for_service() {
   return 0
 }

+usage() {
+    cat << EOF
+📚 Llama-Stack Deployment Script
+
+Description:
+    This script sets up and deploys Llama-Stack with Ollama integration in containers.
+    It handles both Docker and Podman runtimes and includes automatic platform detection.
+
+Usage:
+    $(basename "$0") [OPTIONS]
+
+Options:
+    -p, --port PORT            Server port for Llama-Stack (default: ${PORT})
+    -o, --ollama-port PORT     Ollama service port (default: ${OLLAMA_PORT})
+    -m, --model MODEL          Model alias to use (default: ${MODEL_ALIAS})
+    -i, --image IMAGE          Server image (default: ${SERVER_IMAGE})
+    -t, --timeout SECONDS      Service wait timeout in seconds (default: ${WAIT_TIMEOUT})
+    -h, --help                 Show this help message
+
+For more information:
+    Documentation: https://llama-stack.readthedocs.io/
+    GitHub: https://github.com/meta-llama/llama-stack
+
+Report issues:
+    https://github.com/meta-llama/llama-stack/issues
+EOF
+}
+
+# Parse command line arguments
+while [[ $# -gt 0 ]]; do
+    case $1 in
+        -h|--help)
+            usage
+            exit 0
+            ;;
+        -p|--port)
+            PORT="$2"
+            shift 2
+            ;;
+        -o|--ollama-port)
+            OLLAMA_PORT="$2"
+            shift 2
+            ;;
+        -m|--model)
+            MODEL_ALIAS="$2"
+            shift 2
+            ;;
+        -i|--image)
+            SERVER_IMAGE="$2"
+            shift 2
+            ;;
+        -t|--timeout)
+            WAIT_TIMEOUT="$2"
+            shift 2
+            ;;
+        *)
+            die "Unknown option: $1"
+            ;;
+    esac
+done
+
 if command -v docker &> /dev/null; then
   ENGINE="docker"
 elif command -v podman &> /dev/null; then
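For reference, the parser added above follows the standard `while`/`case`/`shift 2` idiom: each option consumes its value and both are shifted off, and any unrecognized flag aborts via `die`. An invocation such as `./install.sh --port 8322 --timeout 600` (values here are illustrative) would override the defaults echoed in the help text.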
@@ -4,6 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

+import sys
 from collections.abc import AsyncIterator
 from datetime import datetime
 from enum import Enum
@@ -12,6 +13,7 @@ from typing import Annotated, Any, Literal, Protocol, runtime_checkable
 from pydantic import BaseModel, ConfigDict, Field

 from llama_stack.apis.common.content_types import URL, ContentDelta, InterleavedContent
+from llama_stack.apis.common.responses import PaginatedResponse
 from llama_stack.apis.inference import (
     CompletionMessage,
     ResponseFormat,
@@ -29,12 +31,20 @@ from llama_stack.apis.tools import ToolDef
 from llama_stack.schema_utils import json_schema_type, register_schema, webmethod

 from .openai_responses import (
-    OpenAIResponseInputMessage,
+    OpenAIResponseInput,
     OpenAIResponseInputTool,
     OpenAIResponseObject,
     OpenAIResponseObjectStream,
 )

+# TODO: use enum.StrEnum when we drop support for python 3.10
+if sys.version_info >= (3, 11):
+    from enum import StrEnum
+else:
+
+    class StrEnum(str, Enum):
+        """Backport of StrEnum for Python 3.10 and below."""
+

 class Attachment(BaseModel):
     """An attachment to an agent turn.
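The backport above is what enables the `Literal[...]` changes in the hunks below: a `StrEnum` member is itself a `str`, so it can serve both as a `Literal` type argument and as the field default without reaching for `.value`. A minimal sketch of the pattern (class names here are illustrative, not from the diff):

```python
import sys
from enum import Enum
from typing import Literal

from pydantic import BaseModel

if sys.version_info >= (3, 11):
    from enum import StrEnum
else:

    class StrEnum(str, Enum):
        """Backport of StrEnum for Python 3.10 and below."""


class StepKind(StrEnum):
    inference = "inference"


class Step(BaseModel):
    # The enum member doubles as the Literal tag and the default value.
    step_type: Literal[StepKind.inference] = StepKind.inference


# Members compare equal to their plain string values:
assert StepKind.inference == "inference"
assert Step().model_dump()["step_type"] == "inference"
```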
@@ -73,7 +83,7 @@ class StepCommon(BaseModel):
     completed_at: datetime | None = None


-class StepType(Enum):
+class StepType(StrEnum):
     """Type of the step in an agent turn.

     :cvar inference: The step is an inference step that calls an LLM.
@@ -97,7 +107,7 @@ class InferenceStep(StepCommon):

     model_config = ConfigDict(protected_namespaces=())

-    step_type: Literal[StepType.inference.value] = StepType.inference.value
+    step_type: Literal[StepType.inference] = StepType.inference
     model_response: CompletionMessage
@@ -109,7 +119,7 @@ class ToolExecutionStep(StepCommon):
     :param tool_responses: The tool responses from the tool calls.
     """

-    step_type: Literal[StepType.tool_execution.value] = StepType.tool_execution.value
+    step_type: Literal[StepType.tool_execution] = StepType.tool_execution
     tool_calls: list[ToolCall]
     tool_responses: list[ToolResponse]
@@ -121,7 +131,7 @@ class ShieldCallStep(StepCommon):
     :param violation: The violation from the shield call.
     """

-    step_type: Literal[StepType.shield_call.value] = StepType.shield_call.value
+    step_type: Literal[StepType.shield_call] = StepType.shield_call
     violation: SafetyViolation | None
@@ -133,7 +143,7 @@ class MemoryRetrievalStep(StepCommon):
     :param inserted_context: The context retrieved from the vector databases.
     """

-    step_type: Literal[StepType.memory_retrieval.value] = StepType.memory_retrieval.value
+    step_type: Literal[StepType.memory_retrieval] = StepType.memory_retrieval
     # TODO: should this be List[str]?
     vector_db_ids: str
     inserted_context: InterleavedContent
@@ -154,7 +164,7 @@ class Turn(BaseModel):
     input_messages: list[UserMessage | ToolResponseMessage]
     steps: list[Step]
     output_message: CompletionMessage
-    output_attachments: list[Attachment] | None = Field(default_factory=list)
+    output_attachments: list[Attachment] | None = Field(default_factory=lambda: [])

     started_at: datetime
     completed_at: datetime | None = None
@@ -182,10 +192,10 @@ register_schema(AgentToolGroup, name="AgentTool")
 class AgentConfigCommon(BaseModel):
     sampling_params: SamplingParams | None = Field(default_factory=SamplingParams)

-    input_shields: list[str] | None = Field(default_factory=list)
-    output_shields: list[str] | None = Field(default_factory=list)
-    toolgroups: list[AgentToolGroup] | None = Field(default_factory=list)
-    client_tools: list[ToolDef] | None = Field(default_factory=list)
+    input_shields: list[str] | None = Field(default_factory=lambda: [])
+    output_shields: list[str] | None = Field(default_factory=lambda: [])
+    toolgroups: list[AgentToolGroup] | None = Field(default_factory=lambda: [])
+    client_tools: list[ToolDef] | None = Field(default_factory=lambda: [])
     tool_choice: ToolChoice | None = Field(default=None, deprecated="use tool_config instead")
     tool_prompt_format: ToolPromptFormat | None = Field(default=None, deprecated="use tool_config instead")
     tool_config: ToolConfig | None = Field(default=None)
@@ -232,21 +242,11 @@ class Agent(BaseModel):
     created_at: datetime


-@json_schema_type
-class ListAgentsResponse(BaseModel):
-    data: list[Agent]
-
-
-@json_schema_type
-class ListAgentSessionsResponse(BaseModel):
-    data: list[Session]
-
-
 class AgentConfigOverridablePerTurn(AgentConfigCommon):
     instructions: str | None = None


-class AgentTurnResponseEventType(Enum):
+class AgentTurnResponseEventType(StrEnum):
     step_start = "step_start"
     step_complete = "step_complete"
     step_progress = "step_progress"
@@ -258,15 +258,15 @@ class AgentTurnResponseEventType(Enum):

 @json_schema_type
 class AgentTurnResponseStepStartPayload(BaseModel):
-    event_type: Literal[AgentTurnResponseEventType.step_start.value] = AgentTurnResponseEventType.step_start.value
+    event_type: Literal[AgentTurnResponseEventType.step_start] = AgentTurnResponseEventType.step_start
     step_type: StepType
     step_id: str
-    metadata: dict[str, Any] | None = Field(default_factory=dict)
+    metadata: dict[str, Any] | None = Field(default_factory=lambda: {})


 @json_schema_type
 class AgentTurnResponseStepCompletePayload(BaseModel):
-    event_type: Literal[AgentTurnResponseEventType.step_complete.value] = AgentTurnResponseEventType.step_complete.value
+    event_type: Literal[AgentTurnResponseEventType.step_complete] = AgentTurnResponseEventType.step_complete
     step_type: StepType
     step_id: str
     step_details: Step
@@ -276,7 +276,7 @@ class AgentTurnResponseStepCompletePayload(BaseModel):
 class AgentTurnResponseStepProgressPayload(BaseModel):
     model_config = ConfigDict(protected_namespaces=())

-    event_type: Literal[AgentTurnResponseEventType.step_progress.value] = AgentTurnResponseEventType.step_progress.value
+    event_type: Literal[AgentTurnResponseEventType.step_progress] = AgentTurnResponseEventType.step_progress
     step_type: StepType
     step_id: str
@@ -285,21 +285,19 @@ class AgentTurnResponseStepProgressPayload(BaseModel):


 @json_schema_type
 class AgentTurnResponseTurnStartPayload(BaseModel):
-    event_type: Literal[AgentTurnResponseEventType.turn_start.value] = AgentTurnResponseEventType.turn_start.value
+    event_type: Literal[AgentTurnResponseEventType.turn_start] = AgentTurnResponseEventType.turn_start
     turn_id: str


 @json_schema_type
 class AgentTurnResponseTurnCompletePayload(BaseModel):
-    event_type: Literal[AgentTurnResponseEventType.turn_complete.value] = AgentTurnResponseEventType.turn_complete.value
+    event_type: Literal[AgentTurnResponseEventType.turn_complete] = AgentTurnResponseEventType.turn_complete
     turn: Turn


 @json_schema_type
 class AgentTurnResponseTurnAwaitingInputPayload(BaseModel):
-    event_type: Literal[AgentTurnResponseEventType.turn_awaiting_input.value] = (
-        AgentTurnResponseEventType.turn_awaiting_input.value
-    )
+    event_type: Literal[AgentTurnResponseEventType.turn_awaiting_input] = AgentTurnResponseEventType.turn_awaiting_input
     turn: Turn
@@ -341,7 +339,7 @@ class AgentTurnCreateRequest(AgentConfigOverridablePerTurn):
     messages: list[UserMessage | ToolResponseMessage]

     documents: list[Document] | None = None
-    toolgroups: list[AgentToolGroup] | None = None
+    toolgroups: list[AgentToolGroup] | None = Field(default_factory=lambda: [])

     stream: bool | None = False
     tool_config: ToolConfig | None = None
@@ -415,8 +413,9 @@ class Agents(Protocol):
         :param toolgroups: (Optional) List of toolgroups to create the turn with, will be used in addition to the agent's config toolgroups for the request.
         :param tool_config: (Optional) The tool configuration to create the turn with, will be used to override the agent's tool_config.
         :returns: If stream=False, returns a Turn object.
-                  If stream=True, returns an SSE event stream of AgentTurnResponseStreamChunk
+                  If stream=True, returns an SSE event stream of AgentTurnResponseStreamChunk.
         """
+        ...

     @webmethod(
         route="/agents/{agent_id}/session/{session_id}/turn/{turn_id}/resume",
@@ -510,6 +509,7 @@ class Agents(Protocol):
         :param session_id: The ID of the session to get.
         :param agent_id: The ID of the agent to get the session for.
         :param turn_ids: (Optional) List of turn IDs to filter the session by.
+        :returns: A Session.
         """
         ...
@@ -519,7 +519,7 @@ class Agents(Protocol):
         session_id: str,
         agent_id: str,
     ) -> None:
-        """Delete an agent session by its ID.
+        """Delete an agent session by its ID and its associated turns.

         :param session_id: The ID of the session to delete.
         :param agent_id: The ID of the agent to delete the session for.
@@ -531,17 +531,19 @@ class Agents(Protocol):
         self,
         agent_id: str,
     ) -> None:
-        """Delete an agent by its ID.
+        """Delete an agent by its ID and its associated sessions and turns.

         :param agent_id: The ID of the agent to delete.
         """
         ...

     @webmethod(route="/agents", method="GET")
-    async def list_agents(self) -> ListAgentsResponse:
+    async def list_agents(self, start_index: int | None = None, limit: int | None = None) -> PaginatedResponse:
         """List all agents.

-        :returns: A ListAgentsResponse.
+        :param start_index: The index to start the pagination from.
+        :param limit: The number of agents to return.
+        :returns: A PaginatedResponse.
         """
         ...
|
@ -558,11 +560,15 @@ class Agents(Protocol):
|
||||||
async def list_agent_sessions(
|
async def list_agent_sessions(
|
||||||
self,
|
self,
|
||||||
agent_id: str,
|
agent_id: str,
|
||||||
) -> ListAgentSessionsResponse:
|
start_index: int | None = None,
|
||||||
|
limit: int | None = None,
|
||||||
|
) -> PaginatedResponse:
|
||||||
"""List all session(s) of a given agent.
|
"""List all session(s) of a given agent.
|
||||||
|
|
||||||
:param agent_id: The ID of the agent to list sessions for.
|
:param agent_id: The ID of the agent to list sessions for.
|
||||||
:returns: A ListAgentSessionsResponse.
|
:param start_index: The index to start the pagination from.
|
||||||
|
:param limit: The number of sessions to return.
|
||||||
|
:returns: A PaginatedResponse.
|
||||||
"""
|
"""
|
||||||
...
|
...
|
||||||
|
|
||||||
|
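A client-side usage sketch for the paginated listing endpoints above (`client` stands in for any implementation of the protocol; `PaginatedResponse` is assumed to carry `data` and `has_more`, as described in the datasetio hunk further down):

```python
# Hypothetical pagination loop over the new list_agents signature.
async def iter_agents(client, page_size: int = 20):
    start_index = 0
    while True:
        page = await client.list_agents(start_index=start_index, limit=page_size)
        for agent in page.data:
            yield agent
        if not page.has_more:
            break
        start_index += len(page.data)
```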
@@ -588,7 +594,7 @@ class Agents(Protocol):
     @webmethod(route="/openai/v1/responses", method="POST")
     async def create_openai_response(
         self,
-        input: str | list[OpenAIResponseInputMessage],
+        input: str | list[OpenAIResponseInput],
         model: str,
         previous_response_id: str | None = None,
         store: bool | None = True,
|
@ -601,4 +607,6 @@ class Agents(Protocol):
|
||||||
:param input: Input message(s) to create the response.
|
:param input: Input message(s) to create the response.
|
||||||
:param model: The underlying LLM used for completions.
|
:param model: The underlying LLM used for completions.
|
||||||
:param previous_response_id: (Optional) if specified, the new response will be a continuation of the previous response. This can be used to easily fork-off new responses from existing responses.
|
:param previous_response_id: (Optional) if specified, the new response will be a continuation of the previous response. This can be used to easily fork-off new responses from existing responses.
|
||||||
|
:returns: An OpenAIResponseObject.
|
||||||
"""
|
"""
|
||||||
|
...
|
||||||
|
|
|
@@ -4,7 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from typing import Annotated, Literal
+from typing import Annotated, Any, Literal

 from pydantic import BaseModel, Field
@@ -17,6 +17,28 @@ class OpenAIResponseError(BaseModel):
     message: str


+@json_schema_type
+class OpenAIResponseInputMessageContentText(BaseModel):
+    text: str
+    type: Literal["input_text"] = "input_text"
+
+
+@json_schema_type
+class OpenAIResponseInputMessageContentImage(BaseModel):
+    detail: Literal["low"] | Literal["high"] | Literal["auto"] = "auto"
+    type: Literal["input_image"] = "input_image"
+    # TODO: handle file_id
+    image_url: str | None = None
+
+
+# TODO: handle file content types
+OpenAIResponseInputMessageContent = Annotated[
+    OpenAIResponseInputMessageContentText | OpenAIResponseInputMessageContentImage,
+    Field(discriminator="type"),
+]
+register_schema(OpenAIResponseInputMessageContent, name="OpenAIResponseInputMessageContent")
+
+
 @json_schema_type
 class OpenAIResponseOutputMessageContentOutputText(BaseModel):
     text: str
|
@ -31,13 +53,22 @@ register_schema(OpenAIResponseOutputMessageContent, name="OpenAIResponseOutputMe
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
class OpenAIResponseOutputMessage(BaseModel):
|
class OpenAIResponseMessage(BaseModel):
|
||||||
id: str
|
"""
|
||||||
content: list[OpenAIResponseOutputMessageContent]
|
Corresponds to the various Message types in the Responses API.
|
||||||
role: Literal["assistant"] = "assistant"
|
They are all under one type because the Responses API gives them all
|
||||||
status: str
|
the same "type" value, and there is no way to tell them apart in certain
|
||||||
|
scenarios.
|
||||||
|
"""
|
||||||
|
|
||||||
|
content: str | list[OpenAIResponseInputMessageContent] | list[OpenAIResponseOutputMessageContent]
|
||||||
|
role: Literal["system"] | Literal["developer"] | Literal["user"] | Literal["assistant"]
|
||||||
type: Literal["message"] = "message"
|
type: Literal["message"] = "message"
|
||||||
|
|
||||||
|
# The fields below are not used in all scenarios, but are required in others.
|
||||||
|
id: str | None = None
|
||||||
|
status: str | None = None
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
class OpenAIResponseOutputMessageWebSearchToolCall(BaseModel):
|
class OpenAIResponseOutputMessageWebSearchToolCall(BaseModel):
|
||||||
|
@@ -46,8 +77,18 @@ class OpenAIResponseOutputMessageWebSearchToolCall(BaseModel):
     type: Literal["web_search_call"] = "web_search_call"


+@json_schema_type
+class OpenAIResponseOutputMessageFunctionToolCall(BaseModel):
+    arguments: str
+    call_id: str
+    name: str
+    type: Literal["function_call"] = "function_call"
+    id: str
+    status: str
+
+
 OpenAIResponseOutput = Annotated[
-    OpenAIResponseOutputMessage | OpenAIResponseOutputMessageWebSearchToolCall,
+    OpenAIResponseMessage | OpenAIResponseOutputMessageWebSearchToolCall | OpenAIResponseOutputMessageFunctionToolCall,
     Field(discriminator="type"),
 ]
 register_schema(OpenAIResponseOutput, name="OpenAIResponseOutput")
@@ -90,32 +131,29 @@ register_schema(OpenAIResponseObjectStream, name="OpenAIResponseObjectStream")


 @json_schema_type
-class OpenAIResponseInputMessageContentText(BaseModel):
-    text: str
-    type: Literal["input_text"] = "input_text"
-
-
-@json_schema_type
-class OpenAIResponseInputMessageContentImage(BaseModel):
-    detail: Literal["low"] | Literal["high"] | Literal["auto"] = "auto"
-    type: Literal["input_image"] = "input_image"
-    # TODO: handle file_id
-    image_url: str | None = None
-
-
-# TODO: handle file content types
-OpenAIResponseInputMessageContent = Annotated[
-    OpenAIResponseInputMessageContentText | OpenAIResponseInputMessageContentImage,
-    Field(discriminator="type"),
+class OpenAIResponseInputFunctionToolCallOutput(BaseModel):
+    """
+    This represents the output of a function call that gets passed back to the model.
+    """
+
+    call_id: str
+    output: str
+    type: Literal["function_call_output"] = "function_call_output"
+    id: str | None = None
+    status: str | None = None
+
+
+OpenAIResponseInput = Annotated[
+    # Responses API allows output messages to be passed in as input
+    OpenAIResponseOutputMessageWebSearchToolCall
+    | OpenAIResponseOutputMessageFunctionToolCall
+    | OpenAIResponseInputFunctionToolCallOutput
+    |
+    # Fallback to the generic message type as a last resort
+    OpenAIResponseMessage,
+    Field(union_mode="left_to_right"),
 ]
-register_schema(OpenAIResponseInputMessageContent, name="OpenAIResponseInputMessageContent")
-
-
-@json_schema_type
-class OpenAIResponseInputMessage(BaseModel):
-    content: str | list[OpenAIResponseInputMessageContent]
-    role: Literal["system"] | Literal["developer"] | Literal["user"] | Literal["assistant"]
-    type: Literal["message"] | None = "message"
+register_schema(OpenAIResponseInput, name="OpenAIResponseInput")


 @json_schema_type
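The switch to `Field(union_mode="left_to_right")` above is deliberate: several input variants share `type == "message"`, so a tagged (discriminated) union cannot tell them apart, and pydantic must instead try each member in declaration order. A stripped-down sketch of the behavior (model names are illustrative):

```python
from typing import Annotated, Literal

from pydantic import BaseModel, Field, TypeAdapter


class FunctionCallOutput(BaseModel):
    type: Literal["function_call_output"] = "function_call_output"
    call_id: str
    output: str


class GenericMessage(BaseModel):
    type: Literal["message"] = "message"
    role: str
    content: str


# left_to_right: members are attempted in order, with the generic
# message type last as a catch-all.
Item = Annotated[FunctionCallOutput | GenericMessage, Field(union_mode="left_to_right")]

parsed = TypeAdapter(Item).validate_python(
    {"type": "message", "role": "user", "content": "hi"}
)
assert isinstance(parsed, GenericMessage)
```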
@@ -126,8 +164,35 @@ class OpenAIResponseInputToolWebSearch(BaseModel):
     # TODO: add user_location


+@json_schema_type
+class OpenAIResponseInputToolFunction(BaseModel):
+    type: Literal["function"] = "function"
+    name: str
+    description: str | None = None
+    parameters: dict[str, Any] | None
+    strict: bool | None = None
+
+
+class FileSearchRankingOptions(BaseModel):
+    ranker: str | None = None
+    score_threshold: float | None = Field(default=0.0, ge=0.0, le=1.0)
+
+
+@json_schema_type
+class OpenAIResponseInputToolFileSearch(BaseModel):
+    type: Literal["file_search"] = "file_search"
+    vector_store_id: list[str]
+    ranking_options: FileSearchRankingOptions | None = None
+    # TODO: add filters
+
+
 OpenAIResponseInputTool = Annotated[
-    OpenAIResponseInputToolWebSearch,
+    OpenAIResponseInputToolWebSearch | OpenAIResponseInputToolFileSearch | OpenAIResponseInputToolFunction,
     Field(discriminator="type"),
 ]
 register_schema(OpenAIResponseInputTool, name="OpenAIResponseInputTool")


+class OpenAIResponseInputItemList(BaseModel):
+    data: list[OpenAIResponseInput]
+    object: Literal["list"] = "list"
@@ -38,7 +38,17 @@ class BatchInference(Protocol):
         sampling_params: SamplingParams | None = None,
         response_format: ResponseFormat | None = None,
         logprobs: LogProbConfig | None = None,
-    ) -> Job: ...
+    ) -> Job:
+        """Generate completions for a batch of content.
+
+        :param model: The model to use for the completion.
+        :param content_batch: The content to complete.
+        :param sampling_params: The sampling parameters to use for the completion.
+        :param response_format: The response format to use for the completion.
+        :param logprobs: The logprobs to use for the completion.
+        :returns: A job for the completion.
+        """
+        ...

     @webmethod(route="/batch-inference/chat-completion", method="POST")
     async def chat_completion(
@@ -52,4 +62,17 @@ class BatchInference(Protocol):
         tool_prompt_format: ToolPromptFormat | None = None,
         response_format: ResponseFormat | None = None,
         logprobs: LogProbConfig | None = None,
-    ) -> Job: ...
+    ) -> Job:
+        """Generate chat completions for a batch of messages.
+
+        :param model: The model to use for the chat completion.
+        :param messages_batch: The messages to complete.
+        :param sampling_params: The sampling parameters to use for the completion.
+        :param tools: The tools to use for the chat completion.
+        :param tool_choice: The tool choice to use for the chat completion.
+        :param tool_prompt_format: The tool prompt format to use for the chat completion.
+        :param response_format: The response format to use for the chat completion.
+        :param logprobs: The logprobs to use for the chat completion.
+        :returns: A job for the chat completion.
+        """
+        ...
@@ -22,14 +22,14 @@ class CommonBenchmarkFields(BaseModel):

 @json_schema_type
 class Benchmark(CommonBenchmarkFields, Resource):
-    type: Literal[ResourceType.benchmark.value] = ResourceType.benchmark.value
+    type: Literal[ResourceType.benchmark] = ResourceType.benchmark

     @property
     def benchmark_id(self) -> str:
         return self.identifier

     @property
-    def provider_benchmark_id(self) -> str:
+    def provider_benchmark_id(self) -> str | None:
         return self.provider_resource_id
@@ -46,13 +46,24 @@ class ListBenchmarksResponse(BaseModel):
 @runtime_checkable
 class Benchmarks(Protocol):
     @webmethod(route="/eval/benchmarks", method="GET")
-    async def list_benchmarks(self) -> ListBenchmarksResponse: ...
+    async def list_benchmarks(self) -> ListBenchmarksResponse:
+        """List all benchmarks.
+
+        :returns: A ListBenchmarksResponse.
+        """
+        ...

     @webmethod(route="/eval/benchmarks/{benchmark_id}", method="GET")
     async def get_benchmark(
         self,
         benchmark_id: str,
-    ) -> Benchmark: ...
+    ) -> Benchmark:
+        """Get a benchmark by its ID.
+
+        :param benchmark_id: The ID of the benchmark to get.
+        :returns: A Benchmark.
+        """
+        ...

     @webmethod(route="/eval/benchmarks", method="POST")
     async def register_benchmark(
@@ -63,4 +74,14 @@ class Benchmarks(Protocol):
         provider_benchmark_id: str | None = None,
         provider_id: str | None = None,
         metadata: dict[str, Any] | None = None,
-    ) -> None: ...
+    ) -> None:
+        """Register a benchmark.
+
+        :param benchmark_id: The ID of the benchmark to register.
+        :param dataset_id: The ID of the dataset to use for the benchmark.
+        :param scoring_functions: The scoring functions to use for the benchmark.
+        :param provider_benchmark_id: The ID of the provider benchmark to use for the benchmark.
+        :param provider_id: The ID of the provider to use for the benchmark.
+        :param metadata: The metadata to use for the benchmark.
+        """
+        ...
@@ -28,7 +28,7 @@ class _URLOrData(BaseModel):

     url: URL | None = None
     # data is a base64 encoded string, hint with contentEncoding=base64
-    data: str | None = Field(contentEncoding="base64", default=None)
+    data: str | None = Field(default=None, json_schema_extra={"contentEncoding": "base64"})

     @model_validator(mode="before")
     @classmethod
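Context for the `_URLOrData` change: pydantic v2 deprecates passing arbitrary keyword arguments (such as `contentEncoding`) directly to `Field()`; custom JSON Schema keywords are routed through `json_schema_extra` instead. A small sketch of the resulting behavior:

```python
from pydantic import BaseModel, Field


class Blob(BaseModel):
    # The custom "contentEncoding" keyword lands in the generated JSON
    # schema without triggering the deprecated-kwargs warning.
    data: str | None = Field(default=None, json_schema_extra={"contentEncoding": "base64"})


schema = Blob.model_json_schema()
assert "contentEncoding" in str(schema)
```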
@@ -34,14 +34,21 @@ class DatasetIO(Protocol):
         - limit: Number of items to return. If None or -1, returns all items.

         The response includes:
-        - data: List of items for the current page
-        - has_more: Whether there are more items available after this set
+        - data: List of items for the current page.
+        - has_more: Whether there are more items available after this set.

         :param dataset_id: The ID of the dataset to get the rows from.
         :param start_index: Index into dataset for the first row to get. Get all rows if None.
         :param limit: The number of rows to get.
+        :returns: A PaginatedResponse.
         """
         ...

     @webmethod(route="/datasetio/append-rows/{dataset_id:path}", method="POST")
-    async def append_rows(self, dataset_id: str, rows: list[dict[str, Any]]) -> None: ...
+    async def append_rows(self, dataset_id: str, rows: list[dict[str, Any]]) -> None:
+        """Append rows to a dataset.
+
+        :param dataset_id: The ID of the dataset to append the rows to.
+        :param rows: The rows to append to the dataset.
+        """
+        ...
@@ -106,14 +106,14 @@ class CommonDatasetFields(BaseModel):

 @json_schema_type
 class Dataset(CommonDatasetFields, Resource):
-    type: Literal[ResourceType.dataset.value] = ResourceType.dataset.value
+    type: Literal[ResourceType.dataset] = ResourceType.dataset

     @property
     def dataset_id(self) -> str:
         return self.identifier

     @property
-    def provider_dataset_id(self) -> str:
+    def provider_dataset_id(self) -> str | None:
         return self.provider_resource_id
@@ -137,7 +137,8 @@ class Datasets(Protocol):
         """
         Register a new dataset.

-        :param purpose: The purpose of the dataset. One of
+        :param purpose: The purpose of the dataset.
+            One of:
             - "post-training/messages": The dataset contains a messages column with list of messages for post-training.
               {
                   "messages": [
@@ -188,8 +189,9 @@ class Datasets(Protocol):
             ]
         }
         :param metadata: The metadata for the dataset.
-            - E.g. {"description": "My dataset"}
+            - E.g. {"description": "My dataset"}.
         :param dataset_id: The ID of the dataset. If not provided, an ID will be generated.
+        :returns: A Dataset.
         """
         ...
@@ -197,13 +199,29 @@ class Datasets(Protocol):
     async def get_dataset(
         self,
         dataset_id: str,
-    ) -> Dataset: ...
+    ) -> Dataset:
+        """Get a dataset by its ID.
+
+        :param dataset_id: The ID of the dataset to get.
+        :returns: A Dataset.
+        """
+        ...

     @webmethod(route="/datasets", method="GET")
-    async def list_datasets(self) -> ListDatasetsResponse: ...
+    async def list_datasets(self) -> ListDatasetsResponse:
+        """List all datasets.
+
+        :returns: A ListDatasetsResponse.
+        """
+        ...

     @webmethod(route="/datasets/{dataset_id:path}", method="DELETE")
     async def unregister_dataset(
         self,
         dataset_id: str,
-    ) -> None: ...
+    ) -> None:
+        """Unregister a dataset by its ID.
+
+        :param dataset_id: The ID of the dataset to unregister.
+        """
+        ...
@@ -93,8 +93,9 @@ class Eval(Protocol):

         :param benchmark_id: The ID of the benchmark to run the evaluation on.
         :param benchmark_config: The configuration for the benchmark.
-        :return: The job that was created to run the evaluation.
+        :returns: The job that was created to run the evaluation.
         """
+        ...

     @webmethod(route="/eval/benchmarks/{benchmark_id}/evaluations", method="POST")
     async def evaluate_rows(
|
@ -110,8 +111,9 @@ class Eval(Protocol):
|
||||||
:param input_rows: The rows to evaluate.
|
:param input_rows: The rows to evaluate.
|
||||||
:param scoring_functions: The scoring functions to use for the evaluation.
|
:param scoring_functions: The scoring functions to use for the evaluation.
|
||||||
:param benchmark_config: The configuration for the benchmark.
|
:param benchmark_config: The configuration for the benchmark.
|
||||||
:return: EvaluateResponse object containing generations and scores
|
:returns: EvaluateResponse object containing generations and scores.
|
||||||
"""
|
"""
|
||||||
|
...
|
||||||
|
|
||||||
@webmethod(route="/eval/benchmarks/{benchmark_id}/jobs/{job_id}", method="GET")
|
@webmethod(route="/eval/benchmarks/{benchmark_id}/jobs/{job_id}", method="GET")
|
||||||
async def job_status(self, benchmark_id: str, job_id: str) -> Job:
|
async def job_status(self, benchmark_id: str, job_id: str) -> Job:
|
||||||
|
@ -119,7 +121,7 @@ class Eval(Protocol):
|
||||||
|
|
||||||
:param benchmark_id: The ID of the benchmark to run the evaluation on.
|
:param benchmark_id: The ID of the benchmark to run the evaluation on.
|
||||||
:param job_id: The ID of the job to get the status of.
|
:param job_id: The ID of the job to get the status of.
|
||||||
:return: The status of the evaluationjob.
|
:returns: The status of the evaluation job.
|
||||||
"""
|
"""
|
||||||
...
|
...
|
||||||
|
|
||||||
|
@@ -138,5 +140,6 @@ class Eval(Protocol):

         :param benchmark_id: The ID of the benchmark to run the evaluation on.
         :param job_id: The ID of the job to get the result of.
-        :return: The result of the job.
+        :returns: The result of the job.
         """
+        ...
@@ -91,10 +91,11 @@ class Files(Protocol):
         """
         Create a new upload session for a file identified by a bucket and key.

-        :param bucket: Bucket under which the file is stored (valid chars: a-zA-Z0-9_-)
-        :param key: Key under which the file is stored (valid chars: a-zA-Z0-9_-/.)
-        :param mime_type: MIME type of the file
-        :param size: File size in bytes
+        :param bucket: Bucket under which the file is stored (valid chars: a-zA-Z0-9_-).
+        :param key: Key under which the file is stored (valid chars: a-zA-Z0-9_-/.).
+        :param mime_type: MIME type of the file.
+        :param size: File size in bytes.
+        :returns: A FileUploadResponse.
         """
         ...
@@ -107,7 +108,8 @@ class Files(Protocol):
         Upload file content to an existing upload session.
         On the server, request body will have the raw bytes that are uploaded.

-        :param upload_id: ID of the upload session
+        :param upload_id: ID of the upload session.
+        :returns: A FileResponse or None if the upload is not complete.
         """
         ...
@@ -117,9 +119,10 @@ class Files(Protocol):
         upload_id: str,
     ) -> FileUploadResponse:
         """
-        Returns information about an existsing upload session
+        Returns information about an existsing upload session.

-        :param upload_id: ID of the upload session
+        :param upload_id: ID of the upload session.
+        :returns: A FileUploadResponse.
         """
         ...
@@ -130,6 +133,9 @@ class Files(Protocol):
     ) -> ListBucketResponse:
         """
         List all buckets.
+
+        :param bucket: Bucket name (valid chars: a-zA-Z0-9_-).
+        :returns: A ListBucketResponse.
         """
         ...
@@ -141,7 +147,8 @@ class Files(Protocol):
         """
         List all files in a bucket.

-        :param bucket: Bucket name (valid chars: a-zA-Z0-9_-)
+        :param bucket: Bucket name (valid chars: a-zA-Z0-9_-).
+        :returns: A ListFileResponse.
         """
         ...
@@ -154,8 +161,9 @@ class Files(Protocol):
         """
         Get a file info identified by a bucket and key.

-        :param bucket: Bucket name (valid chars: a-zA-Z0-9_-)
-        :param key: Key under which the file is stored (valid chars: a-zA-Z0-9_-/.)
+        :param bucket: Bucket name (valid chars: a-zA-Z0-9_-).
+        :param key: Key under which the file is stored (valid chars: a-zA-Z0-9_-/.).
+        :returns: A FileResponse.
         """
         ...
@@ -168,7 +176,7 @@ class Files(Protocol):
         """
         Delete a file identified by a bucket and key.

-        :param bucket: Bucket name (valid chars: a-zA-Z0-9_-)
-        :param key: Key under which the file is stored (valid chars: a-zA-Z0-9_-/.)
+        :param bucket: Bucket name (valid chars: a-zA-Z0-9_-).
+        :param key: Key under which the file is stored (valid chars: a-zA-Z0-9_-/.).
         """
         ...
@@ -4,6 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

+import sys
 from collections.abc import AsyncIterator
 from enum import Enum
 from typing import (
@@ -35,6 +36,16 @@ register_schema(ToolCall)
 register_schema(ToolParamDefinition)
 register_schema(ToolDefinition)

+# TODO: use enum.StrEnum when we drop support for python 3.10
+if sys.version_info >= (3, 11):
+    from enum import StrEnum
+else:
+
+    class StrEnum(str, Enum):
+        """Backport of StrEnum for Python 3.10 and below."""
+
+        pass
+

 @json_schema_type
 class GreedySamplingStrategy(BaseModel):
@@ -187,7 +198,7 @@ class CompletionMessage(BaseModel):
     role: Literal["assistant"] = "assistant"
     content: InterleavedContent
     stop_reason: StopReason
-    tool_calls: list[ToolCall] | None = Field(default_factory=list)
+    tool_calls: list[ToolCall] | None = Field(default_factory=lambda: [])


 Message = Annotated[
@@ -267,7 +278,7 @@ class ChatCompletionResponseEvent(BaseModel):
     stop_reason: StopReason | None = None


-class ResponseFormatType(Enum):
+class ResponseFormatType(StrEnum):
     """Types of formats for structured (guided) decoding.

     :cvar json_schema: Response should conform to a JSON schema. In a Python SDK, this is often a `pydantic` model.
@@ -286,7 +297,7 @@ class JsonSchemaResponseFormat(BaseModel):
     :param json_schema: The JSON schema the response should conform to. In a Python SDK, this is often a `pydantic` model.
     """

-    type: Literal[ResponseFormatType.json_schema.value] = ResponseFormatType.json_schema.value
+    type: Literal[ResponseFormatType.json_schema] = ResponseFormatType.json_schema
     json_schema: dict[str, Any]
@@ -298,7 +309,7 @@ class GrammarResponseFormat(BaseModel):
     :param bnf: The BNF grammar specification the response should conform to
     """

-    type: Literal[ResponseFormatType.grammar.value] = ResponseFormatType.grammar.value
+    type: Literal[ResponseFormatType.grammar] = ResponseFormatType.grammar
     bnf: dict[str, Any]
@@ -394,7 +405,7 @@ class ChatCompletionRequest(BaseModel):
     messages: list[Message]
     sampling_params: SamplingParams | None = Field(default_factory=SamplingParams)

-    tools: list[ToolDefinition] | None = Field(default_factory=list)
+    tools: list[ToolDefinition] | None = Field(default_factory=lambda: [])
     tool_config: ToolConfig | None = Field(default_factory=ToolConfig)

     response_format: ResponseFormat | None = None
@@ -567,14 +578,14 @@ class OpenAIResponseFormatText(BaseModel):
 @json_schema_type
 class OpenAIJSONSchema(TypedDict, total=False):
     name: str
-    description: str | None = None
-    strict: bool | None = None
+    description: str | None
+    strict: bool | None

     # Pydantic BaseModel cannot be used with a schema param, since it already
     # has one. And, we don't want to alias here because then have to handle
     # that alias when converting to OpenAI params. So, to support schema,
     # we use a TypedDict.
-    schema: dict[str, Any] | None = None
+    schema: dict[str, Any] | None


 @json_schema_type
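The `= None` initializers were dropped above because `TypedDict` members, unlike dataclass or `BaseModel` fields, do not support per-key default values; with `total=False` every key is already optional. A minimal illustration (the class name is hypothetical):

```python
from typing import Any, TypedDict


class JSONSchemaSpec(TypedDict, total=False):
    # total=False makes every key optional; no "= None" defaults are
    # needed (or accepted by type checkers) on TypedDict members.
    name: str
    description: str | None
    schema: dict[str, Any] | None


spec: JSONSchemaSpec = {"name": "my_schema"}  # omitted keys are simply absent
```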
@@ -809,15 +820,32 @@ class BatchChatCompletionResponse(BaseModel):
     batch: list[ChatCompletionResponse]


+class OpenAICompletionWithInputMessages(OpenAIChatCompletion):
+    input_messages: list[OpenAIMessageParam]
+
+
+@json_schema_type
+class ListOpenAIChatCompletionResponse(BaseModel):
+    data: list[OpenAICompletionWithInputMessages]
+    has_more: bool
+    first_id: str
+    last_id: str
+    object: Literal["list"] = "list"
+
+
+class Order(Enum):
+    asc = "asc"
+    desc = "desc"
+
+
 @runtime_checkable
 @trace_protocol
-class Inference(Protocol):
-    """Llama Stack Inference API for generating completions, chat completions, and embeddings.
-
-    This API provides the raw interface to the underlying models. Two kinds of models are supported:
-    - LLM models: these models generate "raw" and "chat" (conversational) completions.
-    - Embedding models: these models generate embeddings to be used for semantic search.
+class InferenceProvider(Protocol):
+    """
+    This protocol defines the interface that should be implemented by all inference providers.
     """

+    API_NAMESPACE: str = "Inference"
+
     model_store: ModelStore | None = None
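On the rename above: because `InferenceProvider` is a `@runtime_checkable` `Protocol`, concrete providers never have to inherit from it; they only need structurally matching methods, and `isinstance` checks method presence (not signatures) at runtime. A toy example of the mechanism, with illustrative names:

```python
from typing import Protocol, runtime_checkable


@runtime_checkable
class Greeter(Protocol):
    def greet(self, name: str) -> str: ...


class EnglishGreeter:  # no inheritance from Greeter
    def greet(self, name: str) -> str:
        return f"Hello, {name}!"


# Structural check: passes because a greet() method exists.
assert isinstance(EnglishGreeter(), Greeter)
```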
@@ -834,13 +862,13 @@ class Inference(Protocol):
         """Generate a completion for the given content using the specified model.

         :param model_id: The identifier of the model to use. The model must be registered with Llama Stack and available via the /models endpoint.
-        :param content: The content to generate a completion for
-        :param sampling_params: (Optional) Parameters to control the sampling strategy
-        :param response_format: (Optional) Grammar specification for guided (structured) decoding
+        :param content: The content to generate a completion for.
+        :param sampling_params: (Optional) Parameters to control the sampling strategy.
+        :param response_format: (Optional) Grammar specification for guided (structured) decoding.
         :param stream: (Optional) If True, generate an SSE event stream of the response. Defaults to False.
         :param logprobs: (Optional) If specified, log probabilities for each token position will be returned.
         :returns: If stream=False, returns a CompletionResponse with the full completion.
-                  If stream=True, returns an SSE event stream of CompletionResponseStreamChunk
+                  If stream=True, returns an SSE event stream of CompletionResponseStreamChunk.
         """
         ...
@@ -853,6 +881,15 @@ class Inference(Protocol):
         response_format: ResponseFormat | None = None,
         logprobs: LogProbConfig | None = None,
     ) -> BatchCompletionResponse:
+        """Generate completions for a batch of content using the specified model.
+
+        :param model_id: The identifier of the model to use. The model must be registered with Llama Stack and available via the /models endpoint.
+        :param content_batch: The content to generate completions for.
+        :param sampling_params: (Optional) Parameters to control the sampling strategy.
+        :param response_format: (Optional) Grammar specification for guided (structured) decoding.
+        :param logprobs: (Optional) If specified, log probabilities for each token position will be returned.
+        :returns: A BatchCompletionResponse with the full completions.
+        """
         raise NotImplementedError("Batch completion is not implemented")

     @webmethod(route="/inference/chat-completion", method="POST")
@@ -872,9 +909,9 @@ class Inference(Protocol):
         """Generate a chat completion for the given messages using the specified model.

         :param model_id: The identifier of the model to use. The model must be registered with Llama Stack and available via the /models endpoint.
-        :param messages: List of messages in the conversation
-        :param sampling_params: Parameters to control the sampling strategy
-        :param tools: (Optional) List of tool definitions available to the model
+        :param messages: List of messages in the conversation.
+        :param sampling_params: Parameters to control the sampling strategy.
+        :param tools: (Optional) List of tool definitions available to the model.
         :param tool_choice: (Optional) Whether tool use is required or automatic. Defaults to ToolChoice.auto.
             .. deprecated::
                Use tool_config instead.
@@ -891,7 +928,7 @@ class Inference(Protocol):
         :param logprobs: (Optional) If specified, log probabilities for each token position will be returned.
         :param tool_config: (Optional) Configuration for tool use.
         :returns: If stream=False, returns a ChatCompletionResponse with the full completion.
-                 If stream=True, returns an SSE event stream of ChatCompletionResponseStreamChunk
+                 If stream=True, returns an SSE event stream of ChatCompletionResponseStreamChunk.
         """
         ...
@@ -906,6 +943,17 @@ class Inference(Protocol):
         response_format: ResponseFormat | None = None,
         logprobs: LogProbConfig | None = None,
     ) -> BatchChatCompletionResponse:
+        """Generate chat completions for a batch of messages using the specified model.
+
+        :param model_id: The identifier of the model to use. The model must be registered with Llama Stack and available via the /models endpoint.
+        :param messages_batch: The messages to generate completions for.
+        :param sampling_params: (Optional) Parameters to control the sampling strategy.
+        :param tools: (Optional) List of tool definitions available to the model.
+        :param tool_config: (Optional) Configuration for tool use.
+        :param response_format: (Optional) Grammar specification for guided (structured) decoding.
+        :param logprobs: (Optional) If specified, log probabilities for each token position will be returned.
+        :returns: A BatchChatCompletionResponse with the full completions.
+        """
         raise NotImplementedError("Batch chat completion is not implemented")

     @webmethod(route="/inference/embeddings", method="POST")
@@ -924,7 +972,7 @@ class Inference(Protocol):
         :param output_dimension: (Optional) Output dimensionality for the embeddings. Only supported by Matryoshka models.
         :param text_truncation: (Optional) Config for how to truncate text for embedding when text is longer than the model's max sequence length.
         :param task_type: (Optional) How is the embedding being used? This is only supported by asymmetric embedding models.
-        :returns: An array of embeddings, one for each content. Each embedding is a list of floats. The dimensionality of the embedding is model-specific; you can check model metadata using /models/{model_id}
+        :returns: An array of embeddings, one for each content. Each embedding is a list of floats. The dimensionality of the embedding is model-specific; you can check model metadata using /models/{model_id}.
         """
         ...
@@ -956,22 +1004,23 @@ class Inference(Protocol):
         """Generate an OpenAI-compatible completion for the given prompt using the specified model.

         :param model: The identifier of the model to use. The model must be registered with Llama Stack and available via the /models endpoint.
-        :param prompt: The prompt to generate a completion for
-        :param best_of: (Optional) The number of completions to generate
-        :param echo: (Optional) Whether to echo the prompt
-        :param frequency_penalty: (Optional) The penalty for repeated tokens
-        :param logit_bias: (Optional) The logit bias to use
-        :param logprobs: (Optional) The log probabilities to use
-        :param max_tokens: (Optional) The maximum number of tokens to generate
-        :param n: (Optional) The number of completions to generate
-        :param presence_penalty: (Optional) The penalty for repeated tokens
-        :param seed: (Optional) The seed to use
-        :param stop: (Optional) The stop tokens to use
-        :param stream: (Optional) Whether to stream the response
-        :param stream_options: (Optional) The stream options to use
-        :param temperature: (Optional) The temperature to use
-        :param top_p: (Optional) The top p to use
-        :param user: (Optional) The user to use
+        :param prompt: The prompt to generate a completion for.
+        :param best_of: (Optional) The number of completions to generate.
+        :param echo: (Optional) Whether to echo the prompt.
+        :param frequency_penalty: (Optional) The penalty for repeated tokens.
+        :param logit_bias: (Optional) The logit bias to use.
+        :param logprobs: (Optional) The log probabilities to use.
+        :param max_tokens: (Optional) The maximum number of tokens to generate.
+        :param n: (Optional) The number of completions to generate.
+        :param presence_penalty: (Optional) The penalty for repeated tokens.
+        :param seed: (Optional) The seed to use.
+        :param stop: (Optional) The stop tokens to use.
+        :param stream: (Optional) Whether to stream the response.
+        :param stream_options: (Optional) The stream options to use.
+        :param temperature: (Optional) The temperature to use.
+        :param top_p: (Optional) The top p to use.
+        :param user: (Optional) The user to use.
+        :returns: An OpenAICompletion.
         """
         ...
@@ -1005,27 +1054,64 @@ class Inference(Protocol):
         """Generate an OpenAI-compatible chat completion for the given messages using the specified model.

         :param model: The identifier of the model to use. The model must be registered with Llama Stack and available via the /models endpoint.
-        :param messages: List of messages in the conversation
-        :param frequency_penalty: (Optional) The penalty for repeated tokens
-        :param function_call: (Optional) The function call to use
-        :param functions: (Optional) List of functions to use
-        :param logit_bias: (Optional) The logit bias to use
-        :param logprobs: (Optional) The log probabilities to use
-        :param max_completion_tokens: (Optional) The maximum number of tokens to generate
-        :param max_tokens: (Optional) The maximum number of tokens to generate
-        :param n: (Optional) The number of completions to generate
-        :param parallel_tool_calls: (Optional) Whether to parallelize tool calls
-        :param presence_penalty: (Optional) The penalty for repeated tokens
-        :param response_format: (Optional) The response format to use
-        :param seed: (Optional) The seed to use
-        :param stop: (Optional) The stop tokens to use
-        :param stream: (Optional) Whether to stream the response
-        :param stream_options: (Optional) The stream options to use
-        :param temperature: (Optional) The temperature to use
-        :param tool_choice: (Optional) The tool choice to use
-        :param tools: (Optional) The tools to use
-        :param top_logprobs: (Optional) The top log probabilities to use
-        :param top_p: (Optional) The top p to use
-        :param user: (Optional) The user to use
+        :param messages: List of messages in the conversation.
+        :param frequency_penalty: (Optional) The penalty for repeated tokens.
+        :param function_call: (Optional) The function call to use.
+        :param functions: (Optional) List of functions to use.
+        :param logit_bias: (Optional) The logit bias to use.
+        :param logprobs: (Optional) The log probabilities to use.
+        :param max_completion_tokens: (Optional) The maximum number of tokens to generate.
+        :param max_tokens: (Optional) The maximum number of tokens to generate.
+        :param n: (Optional) The number of completions to generate.
+        :param parallel_tool_calls: (Optional) Whether to parallelize tool calls.
+        :param presence_penalty: (Optional) The penalty for repeated tokens.
+        :param response_format: (Optional) The response format to use.
+        :param seed: (Optional) The seed to use.
+        :param stop: (Optional) The stop tokens to use.
+        :param stream: (Optional) Whether to stream the response.
+        :param stream_options: (Optional) The stream options to use.
+        :param temperature: (Optional) The temperature to use.
+        :param tool_choice: (Optional) The tool choice to use.
+        :param tools: (Optional) The tools to use.
+        :param top_logprobs: (Optional) The top log probabilities to use.
+        :param top_p: (Optional) The top p to use.
+        :param user: (Optional) The user to use.
+        :returns: An OpenAIChatCompletion.
         """
         ...
+
+
+class Inference(InferenceProvider):
+    """Llama Stack Inference API for generating completions, chat completions, and embeddings.
+
+    This API provides the raw interface to the underlying models. Two kinds of models are supported:
+    - LLM models: these models generate "raw" and "chat" (conversational) completions.
+    - Embedding models: these models generate embeddings to be used for semantic search.
+    """
+
+    @webmethod(route="/openai/v1/chat/completions", method="GET")
+    async def list_chat_completions(
+        self,
+        after: str | None = None,
+        limit: int | None = 20,
+        model: str | None = None,
+        order: Order | None = Order.desc,
+    ) -> ListOpenAIChatCompletionResponse:
+        """List all chat completions.
+
+        :param after: The ID of the last chat completion to return.
+        :param limit: The maximum number of chat completions to return.
+        :param model: The model to filter by.
+        :param order: The order to sort the chat completions by: "asc" or "desc". Defaults to "desc".
+        :returns: A ListOpenAIChatCompletionResponse.
+        """
+        raise NotImplementedError("List chat completions is not implemented")
+
+    @webmethod(route="/openai/v1/chat/completions/{completion_id}", method="GET")
+    async def get_chat_completion(self, completion_id: str) -> OpenAICompletionWithInputMessages:
+        """Describe a chat completion by its ID.
+
+        :param completion_id: ID of the chat completion.
+        :returns: A OpenAICompletionWithInputMessages.
+        """
+        raise NotImplementedError("Get chat completion is not implemented")
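The split above means providers implement only the generation surface (`InferenceProvider`), while the stored-completion endpoints live on `Inference`. A minimal sketch of paging through stored chat completions with the new listing endpoint — `client` is assumed to be any object implementing `Inference`, and the `data`/`has_more` response fields are assumed to follow the OpenAI list convention alongside the `first_id`/`last_id` fields shown above:

```python
# Sketch only: page through stored chat completions, newest first.
from llama_stack.apis.inference import Order


async def dump_chat_completions(client) -> None:
    after = None
    while True:
        page = await client.list_chat_completions(after=after, limit=20, order=Order.desc)
        for completion in page.data:               # `data` assumed per OpenAI convention
            print(completion.id)
        if not getattr(page, "has_more", False):   # `has_more` assumed as well
            break
        after = page.last_id                       # resume after the last returned item
```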
@@ -36,10 +36,25 @@ class ListRoutesResponse(BaseModel):
 @runtime_checkable
 class Inspect(Protocol):
     @webmethod(route="/inspect/routes", method="GET")
-    async def list_routes(self) -> ListRoutesResponse: ...
+    async def list_routes(self) -> ListRoutesResponse:
+        """List all routes.
+
+        :returns: A ListRoutesResponse.
+        """
+        ...

     @webmethod(route="/health", method="GET")
-    async def health(self) -> HealthInfo: ...
+    async def health(self) -> HealthInfo:
+        """Get the health of the service.
+
+        :returns: A HealthInfo.
+        """
+        ...

     @webmethod(route="/version", method="GET")
-    async def version(self) -> VersionInfo: ...
+    async def version(self) -> VersionInfo:
+        """Get the version of the service.
+
+        :returns: A VersionInfo.
+        """
+        ...
@@ -29,14 +29,14 @@ class ModelType(str, Enum):

 @json_schema_type
 class Model(CommonModelFields, Resource):
-    type: Literal[ResourceType.model.value] = ResourceType.model.value
+    type: Literal[ResourceType.model] = ResourceType.model

     @property
     def model_id(self) -> str:
         return self.identifier

     @property
-    def provider_model_id(self) -> str:
+    def provider_model_id(self) -> str | None:
         return self.provider_resource_id

     model_config = ConfigDict(protected_namespaces=())
@@ -80,16 +80,32 @@ class OpenAIListModelsResponse(BaseModel):
 @trace_protocol
 class Models(Protocol):
     @webmethod(route="/models", method="GET")
-    async def list_models(self) -> ListModelsResponse: ...
+    async def list_models(self) -> ListModelsResponse:
+        """List all models.
+
+        :returns: A ListModelsResponse.
+        """
+        ...

     @webmethod(route="/openai/v1/models", method="GET")
-    async def openai_list_models(self) -> OpenAIListModelsResponse: ...
+    async def openai_list_models(self) -> OpenAIListModelsResponse:
+        """List models using the OpenAI API.
+
+        :returns: A OpenAIListModelsResponse.
+        """
+        ...

     @webmethod(route="/models/{model_id:path}", method="GET")
     async def get_model(
         self,
         model_id: str,
-    ) -> Model: ...
+    ) -> Model:
+        """Get a model by its identifier.
+
+        :param model_id: The identifier of the model to get.
+        :returns: A Model.
+        """
+        ...

     @webmethod(route="/models", method="POST")
     async def register_model(
@@ -99,10 +115,25 @@ class Models(Protocol):
         provider_id: str | None = None,
         metadata: dict[str, Any] | None = None,
         model_type: ModelType | None = None,
-    ) -> Model: ...
+    ) -> Model:
+        """Register a model.
+
+        :param model_id: The identifier of the model to register.
+        :param provider_model_id: The identifier of the model in the provider.
+        :param provider_id: The identifier of the provider.
+        :param metadata: Any additional metadata for this model.
+        :param model_type: The type of model to register.
+        :returns: A Model.
+        """
+        ...

     @webmethod(route="/models/{model_id:path}", method="DELETE")
     async def unregister_model(
         self,
         model_id: str,
-    ) -> None: ...
+    ) -> None:
+        """Unregister a model.
+
+        :param model_id: The identifier of the model to unregister.
+        """
+        ...
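A short usage sketch for the registration flow documented above; `client` is assumed to implement the `Models` protocol, and the provider id is illustrative:

```python
from llama_stack.apis.models import ModelType


async def register_embedding_model(client) -> None:
    model = await client.register_model(
        model_id="all-MiniLM-L6-v2",
        provider_id="sentence-transformers",          # illustrative provider id
        metadata={"embedding_dimension": 384},
        model_type=ModelType.embedding,
    )
    print(model.identifier, model.provider_resource_id)  # provider id may now be None
    await client.unregister_model(model_id=model.model_id)
```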
@@ -182,7 +182,19 @@ class PostTraining(Protocol):
         ),
         checkpoint_dir: str | None = None,
         algorithm_config: AlgorithmConfig | None = None,
-    ) -> PostTrainingJob: ...
+    ) -> PostTrainingJob:
+        """Run supervised fine-tuning of a model.
+
+        :param job_uuid: The UUID of the job to create.
+        :param training_config: The training configuration.
+        :param hyperparam_search_config: The hyperparam search configuration.
+        :param logger_config: The logger configuration.
+        :param model: The model to fine-tune.
+        :param checkpoint_dir: The directory to save checkpoint(s) to.
+        :param algorithm_config: The algorithm configuration.
+        :returns: A PostTrainingJob.
+        """
+        ...

     @webmethod(route="/post-training/preference-optimize", method="POST")
     async def preference_optimize(
@@ -193,16 +205,49 @@ class PostTraining(Protocol):
         training_config: TrainingConfig,
         hyperparam_search_config: dict[str, Any],
         logger_config: dict[str, Any],
-    ) -> PostTrainingJob: ...
+    ) -> PostTrainingJob:
+        """Run preference optimization of a model.
+
+        :param job_uuid: The UUID of the job to create.
+        :param finetuned_model: The model to fine-tune.
+        :param algorithm_config: The algorithm configuration.
+        :param training_config: The training configuration.
+        :param hyperparam_search_config: The hyperparam search configuration.
+        :param logger_config: The logger configuration.
+        :returns: A PostTrainingJob.
+        """
+        ...

     @webmethod(route="/post-training/jobs", method="GET")
-    async def get_training_jobs(self) -> ListPostTrainingJobsResponse: ...
+    async def get_training_jobs(self) -> ListPostTrainingJobsResponse:
+        """Get all training jobs.
+
+        :returns: A ListPostTrainingJobsResponse.
+        """
+        ...

     @webmethod(route="/post-training/job/status", method="GET")
-    async def get_training_job_status(self, job_uuid: str) -> PostTrainingJobStatusResponse: ...
+    async def get_training_job_status(self, job_uuid: str) -> PostTrainingJobStatusResponse:
+        """Get the status of a training job.
+
+        :param job_uuid: The UUID of the job to get the status of.
+        :returns: A PostTrainingJobStatusResponse.
+        """
+        ...

     @webmethod(route="/post-training/job/cancel", method="POST")
-    async def cancel_training_job(self, job_uuid: str) -> None: ...
+    async def cancel_training_job(self, job_uuid: str) -> None:
+        """Cancel a training job.
+
+        :param job_uuid: The UUID of the job to cancel.
+        """
+        ...

     @webmethod(route="/post-training/job/artifacts", method="GET")
-    async def get_training_job_artifacts(self, job_uuid: str) -> PostTrainingJobArtifactsResponse: ...
+    async def get_training_job_artifacts(self, job_uuid: str) -> PostTrainingJobArtifactsResponse:
+        """Get the artifacts of a training job.
+
+        :param job_uuid: The UUID of the job to get the artifacts of.
+        :returns: A PostTrainingJobArtifactsResponse.
+        """
+        ...
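The job endpoints above suggest a simple poll loop. This sketch assumes `client` implements `PostTraining` and that the status response exposes a terminal-state marker; the exact field and state names are not shown in this diff and are assumptions:

```python
import asyncio


async def wait_for_training_job(client, job_uuid: str):
    """Poll a fine-tuning job until it leaves the running state, then fetch artifacts."""
    while True:
        status = await client.get_training_job_status(job_uuid=job_uuid)
        # Terminal-state names below are assumed, not taken from this diff.
        if str(status.status) in ("completed", "failed", "cancelled"):
            break
        await asyncio.sleep(30)
    return await client.get_training_job_artifacts(job_uuid=job_uuid)
```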
@@ -32,7 +32,18 @@ class Providers(Protocol):
     """

     @webmethod(route="/providers", method="GET")
-    async def list_providers(self) -> ListProvidersResponse: ...
+    async def list_providers(self) -> ListProvidersResponse:
+        """List all available providers.
+
+        :returns: A ListProvidersResponse containing information about all providers.
+        """
+        ...

     @webmethod(route="/providers/{provider_id}", method="GET")
-    async def inspect_provider(self, provider_id: str) -> ProviderInfo: ...
+    async def inspect_provider(self, provider_id: str) -> ProviderInfo:
+        """Get detailed information about a specific provider.
+
+        :param provider_id: The ID of the provider to inspect.
+        :returns: A ProviderInfo object containing the provider's details.
+        """
+        ...
@@ -4,12 +4,23 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

+import sys
 from enum import Enum

 from pydantic import BaseModel, Field

-class ResourceType(Enum):
+# TODO: use enum.StrEnum when we drop support for python 3.10
+if sys.version_info >= (3, 11):
+    from enum import StrEnum
+else:
+
+    class StrEnum(str, Enum):
+        """Backport of StrEnum for Python 3.10 and below."""
+
+        pass
+
+
+class ResourceType(StrEnum):
     model = "model"
     shield = "shield"
     vector_db = "vector_db"
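Why this backport matters for the `Literal[...]` cleanups later in the diff: a `StrEnum` member is itself a `str`, so schemas and comparisons can use the member directly instead of `.value`. A quick self-contained illustration of the stdlib behavior (not code from this repo):

```python
import sys
from enum import Enum

if sys.version_info >= (3, 11):
    from enum import StrEnum
else:
    class StrEnum(str, Enum):
        """Backport, mirroring the one added in this diff."""
        pass


class ResourceType(StrEnum):
    model = "model"


assert ResourceType.model == "model"        # compares equal to the plain string
assert isinstance(ResourceType.model, str)  # usable wherever a str is expected
assert ResourceType.model.value == "model"  # .value still works but is no longer needed
```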
@@ -25,9 +36,9 @@ class Resource(BaseModel):
     identifier: str = Field(description="Unique identifier for this resource in llama stack")

-    provider_resource_id: str = Field(
-        description="Unique identifier for this resource in the provider",
-        default=None,
+    provider_resource_id: str | None = Field(
+        default=None,
+        description="Unique identifier for this resource in the provider",
     )

     provider_id: str = Field(description="ID of the provider that owns this resource")
@@ -53,5 +53,13 @@ class Safety(Protocol):
         self,
         shield_id: str,
         messages: list[Message],
-        params: dict[str, Any] = None,
-    ) -> RunShieldResponse: ...
+        params: dict[str, Any],
+    ) -> RunShieldResponse:
+        """Run a shield.
+
+        :param shield_id: The identifier of the shield to run.
+        :param messages: The messages to run the shield on.
+        :param params: The parameters of the shield.
+        :returns: A RunShieldResponse.
+        """
+        ...
@@ -61,7 +61,15 @@ class Scoring(Protocol):
         dataset_id: str,
         scoring_functions: dict[str, ScoringFnParams | None],
         save_results_dataset: bool = False,
-    ) -> ScoreBatchResponse: ...
+    ) -> ScoreBatchResponse:
+        """Score a batch of rows.
+
+        :param dataset_id: The ID of the dataset to score.
+        :param scoring_functions: The scoring functions to use for the scoring.
+        :param save_results_dataset: Whether to save the results to a dataset.
+        :returns: A ScoreBatchResponse.
+        """
+        ...

     @webmethod(route="/scoring/score", method="POST")
     async def score(
@@ -73,6 +81,6 @@ class Scoring(Protocol):

         :param input_rows: The rows to score.
         :param scoring_functions: The scoring functions to use for the scoring.
-        :return: ScoreResponse object containing rows and aggregated results
+        :returns: A ScoreResponse object containing rows and aggregated results.
         """
         ...
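A usage sketch for the newly documented batch scoring entry point; `client` is assumed to implement `Scoring`, and the dataset and scoring-function ids are illustrative:

```python
async def score_eval_dataset(client) -> None:
    response = await client.score_batch(
        dataset_id="eval-questions",                   # illustrative dataset id
        scoring_functions={"basic::subset_of": None},  # None = use the function's defaults
        save_results_dataset=True,
    )
    print(response.results)                            # field name assumed on ScoreBatchResponse
```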
@@ -4,6 +4,8 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

+# TODO: use enum.StrEnum when we drop support for python 3.10
+import sys
 from enum import Enum
 from typing import (
     Annotated,
@@ -19,18 +21,27 @@ from llama_stack.apis.common.type_system import ParamType
 from llama_stack.apis.resource import Resource, ResourceType
 from llama_stack.schema_utils import json_schema_type, register_schema, webmethod

+if sys.version_info >= (3, 11):
+    from enum import StrEnum
+else:
+
+    class StrEnum(str, Enum):
+        """Backport of StrEnum for Python 3.10 and below."""
+
+        pass
+
+
 # Perhaps more structure can be imposed on these functions. Maybe they could be associated
 # with standard metrics so they can be rolled up?
 @json_schema_type
-class ScoringFnParamsType(Enum):
+class ScoringFnParamsType(StrEnum):
     llm_as_judge = "llm_as_judge"
     regex_parser = "regex_parser"
     basic = "basic"


 @json_schema_type
-class AggregationFunctionType(Enum):
+class AggregationFunctionType(StrEnum):
     average = "average"
     weighted_average = "weighted_average"
     median = "median"
@@ -40,36 +51,36 @@ class AggregationFunctionType(Enum):

 @json_schema_type
 class LLMAsJudgeScoringFnParams(BaseModel):
-    type: Literal[ScoringFnParamsType.llm_as_judge.value] = ScoringFnParamsType.llm_as_judge.value
+    type: Literal[ScoringFnParamsType.llm_as_judge] = ScoringFnParamsType.llm_as_judge
     judge_model: str
     prompt_template: str | None = None
-    judge_score_regexes: list[str] | None = Field(
+    judge_score_regexes: list[str] = Field(
         description="Regexes to extract the answer from generated response",
-        default_factory=list,
+        default_factory=lambda: [],
     )
-    aggregation_functions: list[AggregationFunctionType] | None = Field(
+    aggregation_functions: list[AggregationFunctionType] = Field(
         description="Aggregation functions to apply to the scores of each row",
-        default_factory=list,
+        default_factory=lambda: [],
     )


 @json_schema_type
 class RegexParserScoringFnParams(BaseModel):
-    type: Literal[ScoringFnParamsType.regex_parser.value] = ScoringFnParamsType.regex_parser.value
-    parsing_regexes: list[str] | None = Field(
+    type: Literal[ScoringFnParamsType.regex_parser] = ScoringFnParamsType.regex_parser
+    parsing_regexes: list[str] = Field(
         description="Regex to extract the answer from generated response",
-        default_factory=list,
+        default_factory=lambda: [],
     )
-    aggregation_functions: list[AggregationFunctionType] | None = Field(
+    aggregation_functions: list[AggregationFunctionType] = Field(
         description="Aggregation functions to apply to the scores of each row",
-        default_factory=list,
+        default_factory=lambda: [],
     )


 @json_schema_type
 class BasicScoringFnParams(BaseModel):
-    type: Literal[ScoringFnParamsType.basic.value] = ScoringFnParamsType.basic.value
-    aggregation_functions: list[AggregationFunctionType] | None = Field(
+    type: Literal[ScoringFnParamsType.basic] = ScoringFnParamsType.basic
+    aggregation_functions: list[AggregationFunctionType] = Field(
         description="Aggregation functions to apply to the scores of each row",
         default_factory=list,
     )
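With the stricter field types above, the params models always carry concrete lists rather than `None`, so consumers can iterate without guards. A small construction sketch; the judge model, template, and regex are illustrative values:

```python
from llama_stack.apis.scoring_functions import (
    AggregationFunctionType,
    LLMAsJudgeScoringFnParams,
)

params = LLMAsJudgeScoringFnParams(
    judge_model="meta-llama/Llama-3.1-8B-Instruct",    # illustrative
    prompt_template="Rate this answer from 0 to 10.",  # illustrative
    judge_score_regexes=[r"Score:\s*(\d+)"],
    aggregation_functions=[AggregationFunctionType.average],
)
assert params.type == "llm_as_judge"       # StrEnum literal serializes as the bare string
for fn in params.aggregation_functions:    # always a list now, never None
    print(fn)
```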
@@ -99,14 +110,14 @@ class CommonScoringFnFields(BaseModel):

 @json_schema_type
 class ScoringFn(CommonScoringFnFields, Resource):
-    type: Literal[ResourceType.scoring_function.value] = ResourceType.scoring_function.value
+    type: Literal[ResourceType.scoring_function] = ResourceType.scoring_function

     @property
     def scoring_fn_id(self) -> str:
         return self.identifier

     @property
-    def provider_scoring_fn_id(self) -> str:
+    def provider_scoring_fn_id(self) -> str | None:
         return self.provider_resource_id
@@ -123,10 +134,21 @@ class ListScoringFunctionsResponse(BaseModel):
 @runtime_checkable
 class ScoringFunctions(Protocol):
     @webmethod(route="/scoring-functions", method="GET")
-    async def list_scoring_functions(self) -> ListScoringFunctionsResponse: ...
+    async def list_scoring_functions(self) -> ListScoringFunctionsResponse:
+        """List all scoring functions.
+
+        :returns: A ListScoringFunctionsResponse.
+        """
+        ...

     @webmethod(route="/scoring-functions/{scoring_fn_id:path}", method="GET")
-    async def get_scoring_function(self, scoring_fn_id: str, /) -> ScoringFn: ...
+    async def get_scoring_function(self, scoring_fn_id: str, /) -> ScoringFn:
+        """Get a scoring function by its ID.
+
+        :param scoring_fn_id: The ID of the scoring function to get.
+        :returns: A ScoringFn.
+        """
+        ...

     @webmethod(route="/scoring-functions", method="POST")
     async def register_scoring_function(
@@ -137,4 +159,14 @@ class ScoringFunctions(Protocol):
         provider_scoring_fn_id: str | None = None,
         provider_id: str | None = None,
         params: ScoringFnParams | None = None,
-    ) -> None: ...
+    ) -> None:
+        """Register a scoring function.
+
+        :param scoring_fn_id: The ID of the scoring function to register.
+        :param description: The description of the scoring function.
+        :param return_type: The return type of the scoring function.
+        :param provider_scoring_fn_id: The ID of the provider scoring function to use for the scoring function.
+        :param provider_id: The ID of the provider to use for the scoring function.
+        :param params: The parameters for the scoring function for benchmark eval, these can be overridden for app eval.
+        """
+        ...
@@ -21,14 +21,14 @@ class CommonShieldFields(BaseModel):
 class Shield(CommonShieldFields, Resource):
     """A safety shield resource that can be used to check content"""

-    type: Literal[ResourceType.shield.value] = ResourceType.shield.value
+    type: Literal[ResourceType.shield] = ResourceType.shield

     @property
     def shield_id(self) -> str:
         return self.identifier

     @property
-    def provider_shield_id(self) -> str:
+    def provider_shield_id(self) -> str | None:
         return self.provider_resource_id
@@ -46,10 +46,21 @@ class ListShieldsResponse(BaseModel):
 @trace_protocol
 class Shields(Protocol):
     @webmethod(route="/shields", method="GET")
-    async def list_shields(self) -> ListShieldsResponse: ...
+    async def list_shields(self) -> ListShieldsResponse:
+        """List all shields.
+
+        :returns: A ListShieldsResponse.
+        """
+        ...

     @webmethod(route="/shields/{identifier:path}", method="GET")
-    async def get_shield(self, identifier: str) -> Shield: ...
+    async def get_shield(self, identifier: str) -> Shield:
+        """Get a shield by its identifier.
+
+        :param identifier: The identifier of the shield to get.
+        :returns: A Shield.
+        """
+        ...

     @webmethod(route="/shields", method="POST")
     async def register_shield(
@@ -58,4 +69,13 @@ class Shields(Protocol):
         provider_shield_id: str | None = None,
         provider_id: str | None = None,
         params: dict[str, Any] | None = None,
-    ) -> Shield: ...
+    ) -> Shield:
+        """Register a shield.
+
+        :param shield_id: The identifier of the shield to register.
+        :param provider_shield_id: The identifier of the shield in the provider.
+        :param provider_id: The identifier of the provider.
+        :param params: The parameters of the shield.
+        :returns: A Shield.
+        """
+        ...
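Tying the `Shields` registry to `Safety.run_shield` from earlier in the diff, a minimal moderation sketch; `client` is assumed to implement both protocols, the provider id is illustrative, and the `violation` field on the response is an assumption:

```python
from llama_stack.apis.inference import UserMessage


async def is_safe(client, text: str) -> bool:
    shield = await client.register_shield(
        shield_id="content-safety",
        provider_id="llama-guard",     # illustrative provider id
    )
    response = await client.run_shield(
        shield_id=shield.shield_id,
        messages=[UserMessage(content=text)],
        params={},
    )
    return response.violation is None  # RunShieldResponse.violation assumed
```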
@@ -37,7 +37,7 @@ class Span(BaseModel):
     name: str
     start_time: datetime
     end_time: datetime | None = None
-    attributes: dict[str, Any] | None = Field(default_factory=dict)
+    attributes: dict[str, Any] | None = Field(default_factory=lambda: {})

     def set_attribute(self, key: str, value: Any):
         if self.attributes is None:
@@ -74,19 +74,19 @@ class EventCommon(BaseModel):
     trace_id: str
     span_id: str
     timestamp: datetime
-    attributes: dict[str, Primitive] | None = Field(default_factory=dict)
+    attributes: dict[str, Primitive] | None = Field(default_factory=lambda: {})


 @json_schema_type
 class UnstructuredLogEvent(EventCommon):
-    type: Literal[EventType.UNSTRUCTURED_LOG.value] = EventType.UNSTRUCTURED_LOG.value
+    type: Literal[EventType.UNSTRUCTURED_LOG] = EventType.UNSTRUCTURED_LOG
     message: str
     severity: LogSeverity


 @json_schema_type
 class MetricEvent(EventCommon):
-    type: Literal[EventType.METRIC.value] = EventType.METRIC.value
+    type: Literal[EventType.METRIC] = EventType.METRIC
     metric: str  # this would be an enum
     value: int | float
     unit: str
@@ -131,14 +131,14 @@ class StructuredLogType(Enum):

 @json_schema_type
 class SpanStartPayload(BaseModel):
-    type: Literal[StructuredLogType.SPAN_START.value] = StructuredLogType.SPAN_START.value
+    type: Literal[StructuredLogType.SPAN_START] = StructuredLogType.SPAN_START
     name: str
     parent_span_id: str | None = None


 @json_schema_type
 class SpanEndPayload(BaseModel):
-    type: Literal[StructuredLogType.SPAN_END.value] = StructuredLogType.SPAN_END.value
+    type: Literal[StructuredLogType.SPAN_END] = StructuredLogType.SPAN_END
     status: SpanStatus

@@ -151,7 +151,7 @@ register_schema(StructuredLogPayload, name="StructuredLogPayload")

 @json_schema_type
 class StructuredLogEvent(EventCommon):
-    type: Literal[EventType.STRUCTURED_LOG.value] = EventType.STRUCTURED_LOG.value
+    type: Literal[EventType.STRUCTURED_LOG] = EventType.STRUCTURED_LOG
     payload: StructuredLogPayload
@@ -203,10 +203,61 @@ class QuerySpanTreeResponse(BaseModel):
     data: dict[str, SpanWithStatus]


+class MetricQueryType(Enum):
+    RANGE = "range"
+    INSTANT = "instant"
+
+
+class MetricLabelOperator(Enum):
+    EQUALS = "="
+    NOT_EQUALS = "!="
+    REGEX_MATCH = "=~"
+    REGEX_NOT_MATCH = "!~"
+
+
+class MetricLabelMatcher(BaseModel):
+    name: str
+    value: str
+    operator: MetricLabelOperator = MetricLabelOperator.EQUALS
+
+
+@json_schema_type
+class MetricLabel(BaseModel):
+    name: str
+    value: str
+
+
+@json_schema_type
+class MetricDataPoint(BaseModel):
+    timestamp: int
+    value: float
+
+
+@json_schema_type
+class MetricSeries(BaseModel):
+    metric: str
+    labels: list[MetricLabel]
+    values: list[MetricDataPoint]
+
+
+class QueryMetricsResponse(BaseModel):
+    data: list[MetricSeries]
+
+
 @runtime_checkable
 class Telemetry(Protocol):
     @webmethod(route="/telemetry/events", method="POST")
-    async def log_event(self, event: Event, ttl_seconds: int = DEFAULT_TTL_DAYS * 86400) -> None: ...
+    async def log_event(
+        self,
+        event: Event,
+        ttl_seconds: int = DEFAULT_TTL_DAYS * 86400,
+    ) -> None:
+        """Log an event.
+
+        :param event: The event to log.
+        :param ttl_seconds: The time to live of the event.
+        """
+        ...

     @webmethod(route="/telemetry/traces", method="POST")
     async def query_traces(
@@ -215,13 +266,35 @@ class Telemetry(Protocol):
         limit: int | None = 100,
         offset: int | None = 0,
         order_by: list[str] | None = None,
-    ) -> QueryTracesResponse: ...
+    ) -> QueryTracesResponse:
+        """Query traces.
+
+        :param attribute_filters: The attribute filters to apply to the traces.
+        :param limit: The limit of traces to return.
+        :param offset: The offset of the traces to return.
+        :param order_by: The order by of the traces to return.
+        :returns: A QueryTracesResponse.
+        """
+        ...

     @webmethod(route="/telemetry/traces/{trace_id:path}", method="GET")
-    async def get_trace(self, trace_id: str) -> Trace: ...
+    async def get_trace(self, trace_id: str) -> Trace:
+        """Get a trace by its ID.
+
+        :param trace_id: The ID of the trace to get.
+        :returns: A Trace.
+        """
+        ...

     @webmethod(route="/telemetry/traces/{trace_id:path}/spans/{span_id:path}", method="GET")
-    async def get_span(self, trace_id: str, span_id: str) -> Span: ...
+    async def get_span(self, trace_id: str, span_id: str) -> Span:
+        """Get a span by its ID.
+
+        :param trace_id: The ID of the trace to get the span from.
+        :param span_id: The ID of the span to get.
+        :returns: A Span.
+        """
+        ...

     @webmethod(route="/telemetry/spans/{span_id:path}/tree", method="POST")
     async def get_span_tree(
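A sketch of the trace-query flow documented above; `client` is assumed to implement `Telemetry`, and the `QueryCondition(key=..., op=..., value=...)` constructor shape, the `data` field on `QueryTracesResponse`, the `root_span_id` field on `Trace`, and the `order_by` key are all assumptions inferred from the signatures, not confirmed by this diff:

```python
from llama_stack.apis.telemetry import QueryCondition


async def recent_session_traces(client, session_id: str) -> None:
    traces = await client.query_traces(
        attribute_filters=[QueryCondition(key="session_id", op="eq", value=session_id)],
        limit=10,
        order_by=["-start_time"],   # ordering key assumed
    )
    for trace in traces.data:       # `data` field assumed
        root = await client.get_span(trace_id=trace.trace_id, span_id=trace.root_span_id)
        print(trace.trace_id, root.name)
```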
@@ -229,7 +302,15 @@ class Telemetry(Protocol):
         span_id: str,
         attributes_to_return: list[str] | None = None,
         max_depth: int | None = None,
-    ) -> QuerySpanTreeResponse: ...
+    ) -> QuerySpanTreeResponse:
+        """Get a span tree by its ID.
+
+        :param span_id: The ID of the span to get the tree from.
+        :param attributes_to_return: The attributes to return in the tree.
+        :param max_depth: The maximum depth of the tree.
+        :returns: A QuerySpanTreeResponse.
+        """
+        ...

     @webmethod(route="/telemetry/spans", method="POST")
     async def query_spans(
@@ -237,7 +318,15 @@ class Telemetry(Protocol):
         attribute_filters: list[QueryCondition],
         attributes_to_return: list[str],
         max_depth: int | None = None,
-    ) -> QuerySpansResponse: ...
+    ) -> QuerySpansResponse:
+        """Query spans.
+
+        :param attribute_filters: The attribute filters to apply to the spans.
+        :param attributes_to_return: The attributes to return in the spans.
+        :param max_depth: The maximum depth of the tree.
+        :returns: A QuerySpansResponse.
+        """
+        ...

     @webmethod(route="/telemetry/spans/export", method="POST")
     async def save_spans_to_dataset(
@@ -246,4 +335,34 @@ class Telemetry(Protocol):
         attributes_to_save: list[str],
         dataset_id: str,
         max_depth: int | None = None,
-    ) -> None: ...
+    ) -> None:
+        """Save spans to a dataset.
+
+        :param attribute_filters: The attribute filters to apply to the spans.
+        :param attributes_to_save: The attributes to save to the dataset.
+        :param dataset_id: The ID of the dataset to save the spans to.
+        :param max_depth: The maximum depth of the tree.
+        """
+        ...
+
+    @webmethod(route="/telemetry/metrics/{metric_name}", method="POST")
+    async def query_metrics(
+        self,
+        metric_name: str,
+        start_time: int,
+        end_time: int | None = None,
+        granularity: str | None = "1d",
+        query_type: MetricQueryType = MetricQueryType.RANGE,
+        label_matchers: list[MetricLabelMatcher] | None = None,
+    ) -> QueryMetricsResponse:
+        """Query metrics.
+
+        :param metric_name: The name of the metric to query.
+        :param start_time: The start time of the metric to query.
+        :param end_time: The end time of the metric to query.
+        :param granularity: The granularity of the metric to query.
+        :param query_type: The type of query to perform.
+        :param label_matchers: The label matchers to apply to the metric.
+        :returns: A QueryMetricsResponse.
+        """
+        ...
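A usage sketch for the new metrics endpoint: fetch a day-granularity range of a token-count metric, filtered by a Prometheus-style label matcher. `client` is assumed to implement `Telemetry`; the metric and label names are illustrative, and the import path is assumed:

```python
import time

from llama_stack.apis.telemetry import (
    MetricLabelMatcher,
    MetricLabelOperator,
    MetricQueryType,
)


async def tokens_last_week(client) -> None:
    now = int(time.time())
    response = await client.query_metrics(
        metric_name="prompt_tokens",            # illustrative metric name
        start_time=now - 7 * 86400,
        end_time=now,
        granularity="1d",
        query_type=MetricQueryType.RANGE,
        label_matchers=[
            MetricLabelMatcher(
                name="model_id",
                value="llama.*",
                operator=MetricLabelOperator.REGEX_MATCH,
            ),
        ],
    )
    for series in response.data:
        print(series.metric, [(p.timestamp, p.value) for p in series.values])
```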
@@ -7,7 +7,7 @@
 from enum import Enum
 from typing import Annotated, Any, Literal

-from pydantic import BaseModel, Field
+from pydantic import BaseModel, Field, field_validator
 from typing_extensions import Protocol, runtime_checkable

 from llama_stack.apis.common.content_types import URL, InterleavedContent
@@ -67,11 +67,33 @@ register_schema(RAGQueryGeneratorConfig, name="RAGQueryGeneratorConfig")

 @json_schema_type
 class RAGQueryConfig(BaseModel):
+    """
+    Configuration for the RAG query generation.
+
+    :param query_generator_config: Configuration for the query generator.
+    :param max_tokens_in_context: Maximum number of tokens in the context.
+    :param max_chunks: Maximum number of chunks to retrieve.
+    :param chunk_template: Template for formatting each retrieved chunk in the context.
+        Available placeholders: {index} (1-based chunk ordinal), {chunk.content} (chunk content string), {metadata} (chunk metadata dict).
+        Default: "Result {index}\\nContent: {chunk.content}\\nMetadata: {metadata}\\n"
+    """
+
     # This config defines how a query is generated using the messages
     # for memory bank retrieval.
     query_generator_config: RAGQueryGeneratorConfig = Field(default=DefaultRAGQueryGeneratorConfig())
     max_tokens_in_context: int = 4096
     max_chunks: int = 5
+    chunk_template: str = "Result {index}\nContent: {chunk.content}\nMetadata: {metadata}\n"
+
+    @field_validator("chunk_template")
+    def validate_chunk_template(cls, v: str) -> str:
+        if "{chunk.content}" not in v:
+            raise ValueError("chunk_template must contain {chunk.content}")
+        if "{index}" not in v:
+            raise ValueError("chunk_template must contain {index}")
+        if len(v) == 0:
+            raise ValueError("chunk_template must not be empty")
+        return v


 @runtime_checkable
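A quick illustration of how such a chunk_template expands at retrieval time. The `chunk` object and expansion call below are illustrative, not the implementation in this commit; the key detail is that `{chunk.content}` uses attribute-style access, so `str.format` must be fed the chunk object itself:

```python
from types import SimpleNamespace

template = "Result {index}\nContent: {chunk.content}\nMetadata: {metadata}\n"

chunk = SimpleNamespace(content="Llama Stack unifies provider APIs.")
rendered = template.format(index=1, chunk=chunk, metadata={"source": "docs/intro.md"})
print(rendered)
# Result 1
# Content: Llama Stack unifies provider APIs.
# Metadata: {'source': 'docs/intro.md'}
```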
@@ -36,7 +36,7 @@ class ToolHost(Enum):

 @json_schema_type
 class Tool(Resource):
-    type: Literal[ResourceType.tool.value] = ResourceType.tool.value
+    type: Literal[ResourceType.tool] = ResourceType.tool
     toolgroup_id: str
     tool_host: ToolHost
     description: str
@@ -62,7 +62,7 @@ class ToolGroupInput(BaseModel):

 @json_schema_type
 class ToolGroup(Resource):
-    type: Literal[ResourceType.tool_group.value] = ResourceType.tool_group.value
+    type: Literal[ResourceType.tool_group] = ResourceType.tool_group
     mcp_endpoint: URL | None = None
     args: dict[str, Any] | None = None
@@ -103,37 +103,65 @@ class ToolGroups(Protocol):
         mcp_endpoint: URL | None = None,
         args: dict[str, Any] | None = None,
     ) -> None:
-        """Register a tool group"""
+        """Register a tool group.
+
+        :param toolgroup_id: The ID of the tool group to register.
+        :param provider_id: The ID of the provider to use for the tool group.
+        :param mcp_endpoint: The MCP endpoint to use for the tool group.
+        :param args: A dictionary of arguments to pass to the tool group.
+        """
         ...

     @webmethod(route="/toolgroups/{toolgroup_id:path}", method="GET")
     async def get_tool_group(
         self,
         toolgroup_id: str,
-    ) -> ToolGroup: ...
+    ) -> ToolGroup:
+        """Get a tool group by its ID.
+
+        :param toolgroup_id: The ID of the tool group to get.
+        :returns: A ToolGroup.
+        """
+        ...

     @webmethod(route="/toolgroups", method="GET")
     async def list_tool_groups(self) -> ListToolGroupsResponse:
-        """List tool groups with optional provider"""
+        """List tool groups with optional provider.
+
+        :returns: A ListToolGroupsResponse.
+        """
         ...

     @webmethod(route="/tools", method="GET")
     async def list_tools(self, toolgroup_id: str | None = None) -> ListToolsResponse:
-        """List tools with optional tool group"""
+        """List tools with optional tool group.
+
+        :param toolgroup_id: The ID of the tool group to list tools for.
+        :returns: A ListToolsResponse.
+        """
         ...

     @webmethod(route="/tools/{tool_name:path}", method="GET")
     async def get_tool(
         self,
         tool_name: str,
-    ) -> Tool: ...
+    ) -> Tool:
+        """Get a tool by its name.
+
+        :param tool_name: The name of the tool to get.
+        :returns: A Tool.
+        """
+        ...

     @webmethod(route="/toolgroups/{toolgroup_id:path}", method="DELETE")
     async def unregister_toolgroup(
         self,
         toolgroup_id: str,
     ) -> None:
-        """Unregister a tool group"""
+        """Unregister a tool group.
+
+        :param toolgroup_id: The ID of the tool group to unregister.
+        """
         ...
@@ -152,9 +180,21 @@ class ToolRuntime(Protocol):
     @webmethod(route="/tool-runtime/list-tools", method="GET")
     async def list_runtime_tools(
         self, tool_group_id: str | None = None, mcp_endpoint: URL | None = None
    ) -> ListToolDefsResponse: ...
+    ) -> ListToolDefsResponse:
+        """List all tools in the runtime.
+
+        :param tool_group_id: The ID of the tool group to list tools for.
+        :param mcp_endpoint: The MCP endpoint to use for the tool group.
+        :returns: A ListToolDefsResponse.
+        """
+        ...

     @webmethod(route="/tool-runtime/invoke", method="POST")
     async def invoke_tool(self, tool_name: str, kwargs: dict[str, Any]) -> ToolInvocationResult:
-        """Run a tool with the given arguments"""
+        """Run a tool with the given arguments.
+
+        :param tool_name: The name of the tool to invoke.
+        :param kwargs: A dictionary of arguments to pass to the tool.
+        :returns: A ToolInvocationResult.
+        """
         ...
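A sketch connecting the tool-group registry to the runtime invocation path; `client` is assumed to implement both `ToolGroups` and `ToolRuntime`, the method name `register_tool_group` is assumed (only its tail is visible in the hunk above), and the endpoint, ids, and tool name are illustrative:

```python
from llama_stack.apis.common.content_types import URL


async def run_mcp_tool(client) -> None:
    await client.register_tool_group(
        toolgroup_id="mcp::filesystem",              # illustrative group id
        provider_id="model-context-protocol",        # illustrative provider id
        mcp_endpoint=URL(uri="http://localhost:8000/sse"),
    )
    result = await client.invoke_tool(
        tool_name="list_directory",                  # hypothetical tool
        kwargs={"path": "/tmp"},
    )
    print(result.content)
```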
@@ -15,7 +15,7 @@ from llama_stack.schema_utils import json_schema_type, webmethod

 @json_schema_type
 class VectorDB(Resource):
-    type: Literal[ResourceType.vector_db.value] = ResourceType.vector_db.value
+    type: Literal[ResourceType.vector_db] = ResourceType.vector_db

     embedding_model: str
     embedding_dimension: int
@ -25,7 +25,7 @@ class VectorDB(Resource):
|
||||||
return self.identifier
|
return self.identifier
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def provider_vector_db_id(self) -> str:
|
def provider_vector_db_id(self) -> str | None:
|
||||||
return self.provider_resource_id
|
return self.provider_resource_id
|
||||||
|
|
||||||
|
|
||||||
|
@ -44,13 +44,24 @@ class ListVectorDBsResponse(BaseModel):
|
||||||
@trace_protocol
|
@trace_protocol
|
||||||
class VectorDBs(Protocol):
|
class VectorDBs(Protocol):
|
||||||
@webmethod(route="/vector-dbs", method="GET")
|
@webmethod(route="/vector-dbs", method="GET")
|
||||||
async def list_vector_dbs(self) -> ListVectorDBsResponse: ...
|
async def list_vector_dbs(self) -> ListVectorDBsResponse:
|
||||||
|
"""List all vector databases.
|
||||||
|
|
||||||
|
:returns: A ListVectorDBsResponse.
|
||||||
|
"""
|
||||||
|
...
|
||||||
|
|
||||||
@webmethod(route="/vector-dbs/{vector_db_id:path}", method="GET")
|
@webmethod(route="/vector-dbs/{vector_db_id:path}", method="GET")
|
||||||
async def get_vector_db(
|
async def get_vector_db(
|
||||||
self,
|
self,
|
||||||
vector_db_id: str,
|
vector_db_id: str,
|
||||||
) -> VectorDB: ...
|
) -> VectorDB:
|
||||||
|
"""Get a vector database by its identifier.
|
||||||
|
|
||||||
|
:param vector_db_id: The identifier of the vector database to get.
|
||||||
|
:returns: A VectorDB.
|
||||||
|
"""
|
||||||
|
...
|
||||||
|
|
||||||
@webmethod(route="/vector-dbs", method="POST")
|
@webmethod(route="/vector-dbs", method="POST")
|
||||||
async def register_vector_db(
|
async def register_vector_db(
|
||||||
|
@ -60,7 +71,22 @@ class VectorDBs(Protocol):
|
||||||
embedding_dimension: int | None = 384,
|
embedding_dimension: int | None = 384,
|
||||||
provider_id: str | None = None,
|
provider_id: str | None = None,
|
||||||
provider_vector_db_id: str | None = None,
|
provider_vector_db_id: str | None = None,
|
||||||
) -> VectorDB: ...
|
) -> VectorDB:
|
||||||
|
"""Register a vector database.
|
||||||
|
|
||||||
|
:param vector_db_id: The identifier of the vector database to register.
|
||||||
|
:param embedding_model: The embedding model to use.
|
||||||
|
:param embedding_dimension: The dimension of the embedding model.
|
||||||
|
:param provider_id: The identifier of the provider.
|
||||||
|
:param provider_vector_db_id: The identifier of the vector database in the provider.
|
||||||
|
:returns: A VectorDB.
|
||||||
|
"""
|
||||||
|
...
|
||||||
|
|
||||||
@webmethod(route="/vector-dbs/{vector_db_id:path}", method="DELETE")
|
@webmethod(route="/vector-dbs/{vector_db_id:path}", method="DELETE")
|
||||||
async def unregister_vector_db(self, vector_db_id: str) -> None: ...
|
async def unregister_vector_db(self, vector_db_id: str) -> None:
|
||||||
|
"""Unregister a vector database.
|
||||||
|
|
||||||
|
:param vector_db_id: The identifier of the vector database to unregister.
|
||||||
|
"""
|
||||||
|
...
|
||||||
|
|
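
A usage sketch for the vector-dbs routes documented above, via the Python client. The server URL, database identifier, and embedding model name are assumptions, not values fixed by this change.

    from llama_stack_client import LlamaStackClient

    client = LlamaStackClient(base_url="http://localhost:8321")  # assumed local server

    # Register a vector database; the id is caller-chosen.
    client.vector_dbs.register(
        vector_db_id="my-documents",
        embedding_model="all-MiniLM-L6-v2",  # assumed to be a registered embedding model
        embedding_dimension=384,             # must match the embedding model's output size
    )

    print([db.identifier for db in client.vector_dbs.list()])
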
@@ -46,7 +46,14 @@ class VectorIO(Protocol):
         vector_db_id: str,
         chunks: list[Chunk],
         ttl_seconds: int | None = None,
-    ) -> None: ...
+    ) -> None:
+        """Insert chunks into a vector database.
+
+        :param vector_db_id: The identifier of the vector database to insert the chunks into.
+        :param chunks: The chunks to insert.
+        :param ttl_seconds: The time to live of the chunks.
+        """
+        ...

     @webmethod(route="/vector-io/query", method="POST")
     async def query_chunks(
@@ -54,4 +61,12 @@ class VectorIO(Protocol):
         vector_db_id: str,
         query: InterleavedContent,
         params: dict[str, Any] | None = None,
-    ) -> QueryChunksResponse: ...
+    ) -> QueryChunksResponse:
+        """Query chunks from a vector database.
+
+        :param vector_db_id: The identifier of the vector database to query.
+        :param query: The query to search for.
+        :param params: The parameters of the query.
+        :returns: A QueryChunksResponse.
+        """
+        ...
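
And a companion sketch for the vector-io routes: inserting a chunk and querying it back. The document_id metadata key follows the usual RAG convention; all concrete values are illustrative.

    from llama_stack_client import LlamaStackClient

    client = LlamaStackClient(base_url="http://localhost:8321")  # assumed local server

    # Insert one chunk into the database registered earlier.
    client.vector_io.insert(
        vector_db_id="my-documents",
        chunks=[
            {
                "content": "Llama Stack unifies inference, safety, and memory APIs.",
                "metadata": {"document_id": "doc-1"},  # conventional key for RAG flows
            }
        ],
    )

    response = client.vector_io.query(vector_db_id="my-documents", query="what does llama stack do?")
    for chunk in response.chunks:
        print(chunk.content)
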
@@ -38,7 +38,10 @@ class LlamaCLIParser:
         print_subcommand_description(self.parser, subparsers)

     def parse_args(self) -> argparse.Namespace:
-        return self.parser.parse_args()
+        args = self.parser.parse_args()
+        if not isinstance(args, argparse.Namespace):
+            raise TypeError(f"Expected argparse.Namespace, got {type(args)}")
+        return args

     def run(self, args: argparse.Namespace) -> None:
         args.func(args)

@@ -12,6 +12,7 @@ import shutil
 import sys
 import textwrap
 from functools import lru_cache
+from importlib.abc import Traversable
 from pathlib import Path

 import yaml

@@ -36,7 +37,8 @@ from llama_stack.distribution.datatypes import (
 )
 from llama_stack.distribution.distribution import get_provider_registry
 from llama_stack.distribution.resolver import InvalidProviderError
-from llama_stack.distribution.utils.config_dirs import DISTRIBS_BASE_DIR
+from llama_stack.distribution.stack import replace_env_vars
+from llama_stack.distribution.utils.config_dirs import DISTRIBS_BASE_DIR, EXTERNAL_PROVIDERS_DIR
 from llama_stack.distribution.utils.dynamic import instantiate_class_type
 from llama_stack.distribution.utils.exec import formulate_run_args, run_command
 from llama_stack.distribution.utils.image_types import LlamaStackImageType

@@ -202,7 +204,11 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
     else:
         with open(args.config) as f:
             try:
-                build_config = BuildConfig(**yaml.safe_load(f))
+                contents = yaml.safe_load(f)
+                contents = replace_env_vars(contents)
+                build_config = BuildConfig(**contents)
+                if args.image_type:
+                    build_config.image_type = args.image_type
             except Exception as e:
                 cprint(
                     f"Could not parse config file {args.config}: {e}",

@@ -245,11 +251,12 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
         sys.exit(1)

     if args.run:
-        run_config = Path(run_config)
         config_dict = yaml.safe_load(run_config.read_text())
         config = parse_and_maybe_upgrade_config(config_dict)
+        if not os.path.exists(config.external_providers_dir):
+            os.makedirs(config.external_providers_dir, exist_ok=True)
         run_args = formulate_run_args(args.image_type, args.image_name, config, args.template)
-        run_args.extend([run_config, str(os.getenv("LLAMA_STACK_PORT", 8321))])
+        run_args.extend([str(os.getenv("LLAMA_STACK_PORT", 8321)), "--config", run_config])
         run_command(run_args)
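
The new replace_env_vars pass means a build config may now reference environment variables with the same ${env.VAR} (and ${env.VAR:default}) syntax that run configs already use. A minimal sketch of the behavior this enables; the config contents are illustrative only.

    import os

    from llama_stack.distribution.stack import replace_env_vars

    os.environ["INFERENCE_PROVIDER"] = "remote::ollama"

    # What yaml.safe_load(build.yaml) might return before validation.
    contents = {"image_name": "my-stack", "providers": {"inference": "${env.INFERENCE_PROVIDER}"}}
    print(replace_env_vars(contents))
    # {'image_name': 'my-stack', 'providers': {'inference': 'remote::ollama'}}
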
@@ -257,7 +264,7 @@ def _generate_run_config(
     build_config: BuildConfig,
     build_dir: Path,
     image_name: str,
-) -> str:
+) -> Path:
     """
     Generate a run.yaml template file for user to edit from a build.yaml file
     """

@@ -267,7 +274,9 @@ def _generate_run_config(
         image_name=image_name,
         apis=apis,
         providers={},
-        external_providers_dir=build_config.external_providers_dir if build_config.external_providers_dir else None,
+        external_providers_dir=build_config.external_providers_dir
+        if build_config.external_providers_dir
+        else EXTERNAL_PROVIDERS_DIR,
     )
     # build providers dict
     provider_registry = get_provider_registry(build_config)

@@ -334,7 +343,7 @@ def _run_stack_build_command_from_build_config(
     image_name: str | None = None,
     template_name: str | None = None,
     config_path: str | None = None,
-) -> str:
+) -> Path | Traversable:
     image_name = image_name or build_config.image_name
     if build_config.image_type == LlamaStackImageType.CONTAINER.value:
         if template_name:

@@ -49,7 +49,7 @@ class StackBuild(Subcommand):
             type=str,
             help="Image Type to use for the build. If not specified, will use the image type from the template config.",
             choices=[e.value for e in ImageType],
-            default=ImageType.CONDA.value,
+            default=None,  # no default so we can detect if a user specified --image-type and override image_type in the config
         )

         self.parser.add_argument(
@@ -46,7 +46,7 @@ class StackListProviders(Subcommand):
         else:
             providers = [(k.value, prov) for k, prov in all_providers.items()]

-        providers = [p for api, p in providers if api in self.providable_apis]
+        providers = [(api, p) for api, p in providers if api in self.providable_apis]

         # eventually, this should query a registry at llama.meta.com/llamastack/distributions
         headers = [

@@ -57,7 +57,7 @@ class StackListProviders(Subcommand):

         rows = []

-        specs = [spec for p in providers for spec in p.values()]
+        specs = [spec for api, p in providers for spec in p.values()]
         for spec in specs:
             if spec.is_sample:
                 continue

@@ -65,7 +65,7 @@ class StackListProviders(Subcommand):
                 [
                     spec.api.value,
                     spec.provider_type,
-                    ",".join(spec.pip_packages),
+                    ",".join(spec.pip_packages) if hasattr(spec, "pip_packages") else "",
                 ]
             )
         print_table(

@@ -33,7 +33,8 @@ class StackRun(Subcommand):
         self.parser.add_argument(
             "config",
             type=str,
-            help="Path to config file to use for the run",
+            nargs="?",  # Make it optional
+            help="Path to config file to use for the run. Required for venv and conda environments.",
         )
         self.parser.add_argument(
             "--port",

@@ -47,28 +48,12 @@ class StackRun(Subcommand):
             default=os.environ.get("CONDA_DEFAULT_ENV"),
             help="Name of the image to run. Defaults to the current environment",
         )
-        self.parser.add_argument(
-            "--disable-ipv6",
-            action="store_true",
-            help="Disable IPv6 support",
-            default=False,
-        )
         self.parser.add_argument(
             "--env",
             action="append",
             help="Environment variables to pass to the server in KEY=VALUE format. Can be specified multiple times.",
             metavar="KEY=VALUE",
         )
-        self.parser.add_argument(
-            "--tls-keyfile",
-            type=str,
-            help="Path to TLS key file for HTTPS",
-        )
-        self.parser.add_argument(
-            "--tls-certfile",
-            type=str,
-            help="Path to TLS certificate file for HTTPS",
-        )
         self.parser.add_argument(
             "--image-type",
             type=str,

@@ -98,6 +83,13 @@ class StackRun(Subcommand):
         from llama_stack.distribution.utils.config_dirs import DISTRIBS_BASE_DIR
         from llama_stack.distribution.utils.exec import formulate_run_args, run_command

+        image_type, image_name = self._get_image_type_and_name(args)
+
+        # Check if config is required based on image type
+        if (image_type in [ImageType.CONDA.value, ImageType.VENV.value]) and not args.config:
+            self.parser.error("Config file is required for venv and conda environments")
+
+        if args.config:
             config_file = Path(args.config)
             has_yaml_suffix = args.config.endswith(".yaml")
             template_name = None

@@ -131,10 +123,14 @@ class StackRun(Subcommand):

             try:
                 config = parse_and_maybe_upgrade_config(config_dict)
+                if not os.path.exists(str(config.external_providers_dir)):
+                    os.makedirs(str(config.external_providers_dir), exist_ok=True)
             except AttributeError as e:
                 self.parser.error(f"failed to parse config file '{config_file}':\n {e}")
-            image_type, image_name = self._get_image_type_and_name(args)
+        else:
+            config = None
+            config_file = None
+            template_name = None

         # If neither image type nor image name is provided, assume the server should be run directly
         # using the current environment packages.

@@ -157,9 +153,10 @@ class StackRun(Subcommand):
         else:
             run_args = formulate_run_args(image_type, image_name, config, template_name)

-            run_args.extend([str(config_file), str(args.port)])
-            if args.disable_ipv6:
-                run_args.append("--disable-ipv6")
+            run_args.extend([str(args.port)])
+
+            if config_file:
+                run_args.extend(["--config", str(config_file)])

             if args.env:
                 for env_var in args.env:

@@ -172,6 +169,4 @@ class StackRun(Subcommand):
                         return
                     run_args.extend(["--env", f"{key}={value}"])

-            if args.tls_keyfile and args.tls_certfile:
-                run_args.extend(["--tls-keyfile", args.tls_keyfile, "--tls-certfile", args.tls_certfile])
             run_command(run_args)
@@ -154,6 +154,12 @@ get_python_cmd() {
   fi
 }

+# Add other required item commands generic to all containers
+add_to_container << EOF
+# Allows running as non-root user
+RUN mkdir -p /.llama/providers.d /.cache
+EOF
+
 if [ -n "$run_config" ]; then
   # Copy the run config to the build context since it's an absolute path
   cp "$run_config" "$BUILD_CONTEXT_DIR/run.yaml"

@@ -166,17 +172,19 @@ EOF
   # and update the configuration to reference the new container path
   python_cmd=$(get_python_cmd)
   external_providers_dir=$($python_cmd -c "import yaml; config = yaml.safe_load(open('$run_config')); print(config.get('external_providers_dir') or '')")
-  if [ -n "$external_providers_dir" ]; then
+  external_providers_dir=$(eval echo "$external_providers_dir")
+  if [ -n "$external_providers_dir" ] && [ -d "$external_providers_dir" ]; then
     echo "Copying external providers directory: $external_providers_dir"
+    cp -r "$external_providers_dir" "$BUILD_CONTEXT_DIR/providers.d"
     add_to_container << EOF
-COPY $external_providers_dir /app/providers.d
+COPY providers.d /.llama/providers.d
 EOF
-    # Edit the run.yaml file to change the external_providers_dir to /app/providers.d
+    # Edit the run.yaml file to change the external_providers_dir to /.llama/providers.d
     if [ "$(uname)" = "Darwin" ]; then
-      sed -i.bak -e 's|external_providers_dir:.*|external_providers_dir: /app/providers.d|' "$BUILD_CONTEXT_DIR/run.yaml"
+      sed -i.bak -e 's|external_providers_dir:.*|external_providers_dir: /.llama/providers.d|' "$BUILD_CONTEXT_DIR/run.yaml"
       rm -f "$BUILD_CONTEXT_DIR/run.yaml.bak"
     else
-      sed -i 's|external_providers_dir:.*|external_providers_dir: /app/providers.d|' "$BUILD_CONTEXT_DIR/run.yaml"
+      sed -i 's|external_providers_dir:.*|external_providers_dir: /.llama/providers.d|' "$BUILD_CONTEXT_DIR/run.yaml"
     fi
   fi
 fi

@@ -255,9 +263,6 @@ fi
 # Add other require item commands genearic to all containers
 add_to_container << EOF

-# Allows running as non-root user
-RUN mkdir -p /.llama /.cache
-
 RUN chmod -R g+rw /app /.llama /.cache
 EOF
@@ -17,6 +17,7 @@ from llama_stack.distribution.distribution import (
     builtin_automatically_routed_apis,
     get_provider_registry,
 )
+from llama_stack.distribution.utils.config_dirs import EXTERNAL_PROVIDERS_DIR
 from llama_stack.distribution.utils.dynamic import instantiate_class_type
 from llama_stack.distribution.utils.prompt_for_config import prompt_for_config
 from llama_stack.providers.datatypes import Api, ProviderSpec

@@ -73,11 +74,7 @@ def configure_api_providers(config: StackRunConfig, build_spec: DistributionSpec

         existing_providers = config.providers.get(api_str, [])
         if existing_providers:
-            logger.info(
-                f"Re-configuring existing providers for API `{api_str}`...",
-                "green",
-                attrs=["bold"],
-            )
+            logger.info(f"Re-configuring existing providers for API `{api_str}`...")
             updated_providers = []
             for p in existing_providers:
                 logger.info(f"> Configuring provider `({p.provider_type})`")

@@ -91,7 +88,7 @@ def configure_api_providers(config: StackRunConfig, build_spec: DistributionSpec
         if not plist:
             raise ValueError(f"No provider configured for API {api_str}?")

-        logger.info(f"Configuring API `{api_str}`...", "green", attrs=["bold"])
+        logger.info(f"Configuring API `{api_str}`...")
         updated_providers = []
         for i, provider_type in enumerate(plist):
             if i >= 1:

@@ -174,4 +171,7 @@ def parse_and_maybe_upgrade_config(config_dict: dict[str, Any]) -> StackRunConfig:

     config_dict["version"] = LLAMA_STACK_RUN_CONFIG_VERSION

+    if not config_dict.get("external_providers_dir", None):
+        config_dict["external_providers_dir"] = EXTERNAL_PROVIDERS_DIR
+
     return StackRunConfig(**config_dict)
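
One observable effect of the configure change above: a run config that omits external_providers_dir now comes back with the shared default filled in. A sketch with a deliberately minimal config dict, assuming the remaining StackRunConfig fields keep their defaults; the field values are illustrative.

    from llama_stack.distribution.configure import parse_and_maybe_upgrade_config

    config = parse_and_maybe_upgrade_config(
        {"version": "2", "image_name": "minimal", "apis": [], "providers": {}}
    )
    # Falls back to EXTERNAL_PROVIDERS_DIR, a path under the user's .llama directory.
    print(config.external_providers_dir)
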
@@ -5,9 +5,10 @@
 # the root directory of this source tree.

 from enum import Enum
+from pathlib import Path
 from typing import Annotated, Any

-from pydantic import BaseModel, Field
+from pydantic import BaseModel, Field, field_validator

 from llama_stack.apis.benchmarks import Benchmark, BenchmarkInput
 from llama_stack.apis.datasetio import DatasetIO

@@ -249,10 +250,18 @@ class ServerConfig(BaseModel):
         default=None,
         description="Path to TLS key file for HTTPS",
     )
+    tls_cafile: str | None = Field(
+        default=None,
+        description="Path to TLS CA file for HTTPS with mutual TLS authentication",
+    )
     auth: AuthenticationConfig | None = Field(
         default=None,
         description="Authentication configuration for the server",
     )
+    host: str | None = Field(
+        default=None,
+        description="The host the server should listen on",
+    )


 class StackRunConfig(BaseModel):

@@ -304,11 +313,20 @@ a default SQLite store will be used.""",
         description="Configuration for the HTTP(S) server",
     )

-    external_providers_dir: str | None = Field(
+    external_providers_dir: Path | None = Field(
         default=None,
         description="Path to directory containing external provider implementations. The providers code and dependencies must be installed on the system.",
     )

+    @field_validator("external_providers_dir")
+    @classmethod
+    def validate_external_providers_dir(cls, v):
+        if v is None:
+            return None
+        if isinstance(v, str):
+            return Path(v)
+        return v
+

 class BuildConfig(BaseModel):
     version: str = LLAMA_STACK_BUILD_CONFIG_VERSION

@@ -322,8 +340,17 @@ class BuildConfig(BaseModel):
         default=None,
         description="Name of the distribution to build",
     )
-    external_providers_dir: str | None = Field(
+    external_providers_dir: Path | None = Field(
         default=None,
         description="Path to directory containing external provider implementations. The providers packages will be resolved from this directory. "
         "pip_packages MUST contain the provider package name.",
     )
+
+    @field_validator("external_providers_dir")
+    @classmethod
+    def validate_external_providers_dir(cls, v):
+        if v is None:
+            return None
+        if isinstance(v, str):
+            return Path(v)
+        return v
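
The new field_validator normalizes a plain string, as it would arrive from YAML, into a pathlib.Path. A quick sketch of that coercion on BuildConfig, assuming its other fields keep their defaults; the directory value is an arbitrary example.

    from llama_stack.distribution.datatypes import BuildConfig, DistributionSpec

    build = BuildConfig(
        distribution_spec=DistributionSpec(providers={}),
        image_type="venv",
        external_providers_dir="~/.llama/providers.d",  # plain string, as in YAML
    )
    print(type(build.external_providers_dir))  # pathlib.Path after validation
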
@@ -145,7 +145,7 @@ def get_provider_registry(

     # Check if config has the external_providers_dir attribute
     if config and hasattr(config, "external_providers_dir") and config.external_providers_dir:
-        external_providers_dir = os.path.abspath(config.external_providers_dir)
+        external_providers_dir = os.path.abspath(os.path.expanduser(config.external_providers_dir))
         if not os.path.exists(external_providers_dir):
             raise FileNotFoundError(f"External providers directory not found: {external_providers_dir}")
         logger.info(f"Loading external providers from {external_providers_dir}")
@@ -30,7 +30,7 @@ from termcolor import cprint

 from llama_stack.distribution.build import print_pip_install_help
 from llama_stack.distribution.configure import parse_and_maybe_upgrade_config
-from llama_stack.distribution.datatypes import Api
+from llama_stack.distribution.datatypes import Api, BuildConfig, DistributionSpec
 from llama_stack.distribution.request_headers import (
     PROVIDER_DATA_VAR,
     request_provider_data_context,

@@ -216,7 +216,19 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
                 "yellow",
             )
             if self.config_path_or_template_name.endswith(".yaml"):
-                print_pip_install_help(self.config.providers)
+                # Convert Provider objects to their types
+                provider_types: dict[str, str | list[str]] = {}
+                for api, providers in self.config.providers.items():
+                    types = [p.provider_type for p in providers]
+                    # Convert single-item lists to strings
+                    provider_types[api] = types[0] if len(types) == 1 else types
+                build_config = BuildConfig(
+                    distribution_spec=DistributionSpec(
+                        providers=provider_types,
+                    ),
+                    external_providers_dir=self.config.external_providers_dir,
+                )
+                print_pip_install_help(build_config)
             else:
                 prefix = "!" if in_notebook() else ""
                 cprint(
@@ -99,7 +99,7 @@ class ProviderImpl(Providers):
             try:
                 health = await asyncio.wait_for(impl.health(), timeout=timeout)
                 return api_name, health
-            except asyncio.TimeoutError:
+            except (asyncio.TimeoutError, TimeoutError):
                 return (
                     api_name,
                     HealthResponse(
@@ -44,7 +44,8 @@ class RequestProviderDataContext(AbstractContextManager):
 class NeedsRequestProviderData:
     def get_request_provider_data(self) -> Any:
         spec = self.__provider_spec__
-        assert spec, f"Provider spec not set on {self.__class__}"
+        if not spec:
+            raise ValueError(f"Provider spec not set on {self.__class__}")

         provider_type = spec.provider_type
         validator_class = spec.provider_data_validator
@@ -13,7 +13,7 @@ from llama_stack.apis.datasetio import DatasetIO
 from llama_stack.apis.datasets import Datasets
 from llama_stack.apis.eval import Eval
 from llama_stack.apis.files import Files
-from llama_stack.apis.inference import Inference
+from llama_stack.apis.inference import Inference, InferenceProvider
 from llama_stack.apis.inspect import Inspect
 from llama_stack.apis.models import Models
 from llama_stack.apis.post_training import PostTraining

@@ -83,6 +83,13 @@ def api_protocol_map() -> dict[Api, Any]:
     }


+def api_protocol_map_for_compliance_check() -> dict[Api, Any]:
+    return {
+        **api_protocol_map(),
+        Api.inference: InferenceProvider,
+    }
+
+
 def additional_protocols_map() -> dict[Api, Any]:
     return {
         Api.inference: (ModelsProtocolPrivate, Models, Api.models),

@@ -302,9 +309,6 @@ async def instantiate_provider(
     inner_impls: dict[str, Any],
     dist_registry: DistributionRegistry,
 ):
-    protocols = api_protocol_map()
-    additional_protocols = additional_protocols_map()
-
     provider_spec = provider.spec
     if not hasattr(provider_spec, "module"):
         raise AttributeError(f"ProviderSpec of type {type(provider_spec)} does not have a 'module' attribute")

@@ -342,6 +346,8 @@ async def instantiate_provider(
     impl.__provider_spec__ = provider_spec
     impl.__provider_config__ = config

+    protocols = api_protocol_map_for_compliance_check()
+    additional_protocols = additional_protocols_map()
     # TODO: check compliance for special tool groups
     # the impl should be for Api.tool_runtime, the name should be the special tool group, the protocol should be the special tool group protocol
     check_protocol_compliance(impl, protocols[provider_spec.api])
@@ -573,6 +573,12 @@ class InferenceRouter(Inference):
             for tool in tools:
                 TypeAdapter(OpenAIChatCompletionToolParam).validate_python(tool)

+        # Some providers make tool calls even when tool_choice is "none"
+        # so just clear them both out to avoid unexpected tool calls
+        if tool_choice == "none" and tools is not None:
+            tool_choice = None
+            tools = None
+
         params = dict(
             model=model_obj.identifier,
             messages=messages,

@@ -600,7 +606,19 @@ class InferenceRouter(Inference):
         )

         provider = self.routing_table.get_provider_impl(model_obj.identifier)
+        if stream:
             return await provider.openai_chat_completion(**params)
+        else:
+            return await self._nonstream_openai_chat_completion(provider, params)
+
+    async def _nonstream_openai_chat_completion(self, provider: Inference, params: dict) -> OpenAIChatCompletion:
+        response = await provider.openai_chat_completion(**params)
+        for choice in response.choices:
+            # some providers return an empty list for no tool calls in non-streaming responses
+            # but the OpenAI API returns None. So, set tool_calls to None if it's empty
+            if choice.message and choice.message.tool_calls is not None and len(choice.message.tool_calls) == 0:
+                choice.message.tool_calls = None
+        return response

     async def health(self) -> dict[str, HealthResponse]:
         health_statuses = {}

@@ -612,7 +630,7 @@ class InferenceRouter(Inference):
                     continue
                 health = await asyncio.wait_for(impl.health(), timeout=timeout)
                 health_statuses[provider_id] = health
-            except asyncio.TimeoutError:
+            except (asyncio.TimeoutError, TimeoutError):
                 health_statuses[provider_id] = HealthResponse(
                     status=HealthStatus.ERROR,
                     message=f"Health check timed out after {timeout} seconds",
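
The tool_choice guard above matters for OpenAI-compatible clients that send tool definitions together with tool_choice="none": the router now clears both before the request reaches a provider. A sketch of such a request through the OpenAI SDK; the base URL and model identifier are assumptions about a locally running stack.

    from openai import OpenAI

    client = OpenAI(base_url="http://localhost:8321/v1/openai/v1", api_key="none")

    weather_tool = {
        "type": "function",
        "function": {
            "name": "get_weather",  # illustrative tool definition
            "parameters": {"type": "object", "properties": {"city": {"type": "string"}}},
        },
    }

    response = client.chat.completions.create(
        model="llama3.2:3b",  # assumed registered model identifier
        messages=[{"role": "user", "content": "Say hello."}],
        tools=[weather_tool],
        tool_choice="none",  # router strips tools and tool_choice, so no tool call happens
    )
    print(response.choices[0].message.content)
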
@@ -93,7 +93,7 @@ class AuthenticationMiddleware:

         # Validate token and get access attributes
         try:
-            access_attributes = await self.auth_provider.validate_token(token, scope)
+            validation_result = await self.auth_provider.validate_token(token, scope)
         except httpx.TimeoutException:
             logger.exception("Authentication request timed out")
             return await self._send_auth_error(send, "Authentication service timeout")

@@ -105,17 +105,20 @@ class AuthenticationMiddleware:
             return await self._send_auth_error(send, "Authentication service error")

         # Store attributes in request scope for access control
-        if access_attributes:
-            user_attributes = access_attributes.model_dump(exclude_none=True)
+        if validation_result.access_attributes:
+            user_attributes = validation_result.access_attributes.model_dump(exclude_none=True)
         else:
             logger.warning("No access attributes, setting namespace to token by default")
             user_attributes = {
-                "namespaces": [token],
+                "roles": [token],
             }

         # Store attributes in request scope
         scope["user_attributes"] = user_attributes
-        logger.debug(f"Authentication successful: {len(scope['user_attributes'])} attributes")
+        scope["principal"] = validation_result.principal
+        logger.debug(
+            f"Authentication successful: {validation_result.principal} with {len(scope['user_attributes'])} attributes"
+        )

         return await self.app(scope, receive, send)
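
The middleware now consumes the richer result type introduced in the auth provider changes that follow. A sketch of the shape it expects back from validate_token, with illustrative values; the module path for the import is assumed.

    from llama_stack.distribution.datatypes import AccessAttributes
    # Module path assumed; the classes are defined in the server auth provider module below.
    from llama_stack.distribution.server.auth_providers import TokenValidationResult

    result = TokenValidationResult(
        principal="alice",  # illustrative user identity
        access_attributes=AccessAttributes(roles=["admin"], teams=["ml-infra"]),
    )
    print(result.model_dump(exclude_none=True))
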
@@ -5,12 +5,14 @@
 # the root directory of this source tree.

 import json
+import time
 from abc import ABC, abstractmethod
 from enum import Enum
 from urllib.parse import parse_qs

 import httpx
-from pydantic import BaseModel, Field
+from jose import jwt
+from pydantic import BaseModel, Field, field_validator

 from llama_stack.distribution.datatypes import AccessAttributes
 from llama_stack.log import get_logger

@@ -18,9 +20,11 @@ from llama_stack.log import get_logger
 logger = get_logger(name=__name__, category="auth")


-class AuthResponse(BaseModel):
-    """The format of the authentication response from the auth endpoint."""
+class TokenValidationResult(BaseModel):
+    principal: str | None = Field(
+        default=None,
+        description="The principal (username or persistent identifier) of the authenticated user",
+    )
     access_attributes: AccessAttributes | None = Field(
         default=None,
         description="""

@@ -43,6 +47,10 @@ class AuthResponse(BaseModel):
         """,
     )

+
+class AuthResponse(TokenValidationResult):
+    """The format of the authentication response from the auth endpoint."""
+
     message: str | None = Field(
         default=None, description="Optional message providing additional context about the authentication result."
     )

@@ -69,6 +77,7 @@ class AuthProviderType(str, Enum):

     KUBERNETES = "kubernetes"
     CUSTOM = "custom"
+    OAUTH2_TOKEN = "oauth2_token"


 class AuthProviderConfig(BaseModel):

@@ -82,7 +91,7 @@ class AuthProvider(ABC):
     """Abstract base class for authentication providers."""

     @abstractmethod
-    async def validate_token(self, token: str, scope: dict | None = None) -> AccessAttributes | None:
+    async def validate_token(self, token: str, scope: dict | None = None) -> TokenValidationResult:
         """Validate a token and return access attributes."""
         pass

@@ -92,12 +101,16 @@ class AuthProvider(ABC):
         pass


+class KubernetesAuthProviderConfig(BaseModel):
+    api_server_url: str
+    ca_cert_path: str | None = None
+
+
 class KubernetesAuthProvider(AuthProvider):
     """Kubernetes authentication provider that validates tokens against the Kubernetes API server."""

-    def __init__(self, config: dict[str, str]):
-        self.api_server_url = config["api_server_url"]
-        self.ca_cert_path = config.get("ca_cert_path")
+    def __init__(self, config: KubernetesAuthProviderConfig):
+        self.config = config
         self._client = None

     async def _get_client(self):

@@ -110,16 +123,16 @@ class KubernetesAuthProvider(AuthProvider):

             # Configure the client
             configuration = client.Configuration()
-            configuration.host = self.api_server_url
-            if self.ca_cert_path:
-                configuration.ssl_ca_cert = self.ca_cert_path
-            configuration.verify_ssl = bool(self.ca_cert_path)
+            configuration.host = self.config.api_server_url
+            if self.config.ca_cert_path:
+                configuration.ssl_ca_cert = self.config.ca_cert_path
+            configuration.verify_ssl = bool(self.config.ca_cert_path)

             # Create API client
             self._client = ApiClient(configuration)
         return self._client

-    async def validate_token(self, token: str, scope: dict | None = None) -> AccessAttributes | None:
+    async def validate_token(self, token: str, scope: dict | None = None) -> TokenValidationResult:
         """Validate a Kubernetes token and return access attributes."""
         try:
             client = await self._get_client()

@@ -146,9 +159,12 @@ class KubernetesAuthProvider(AuthProvider):
             username = payload.get("sub", "")
             groups = payload.get("groups", [])

-            return AccessAttributes(
-                roles=[username],  # Use username as a role
-                teams=groups,  # Use Kubernetes groups as teams
+            return TokenValidationResult(
+                principal=username,
+                access_attributes=AccessAttributes(
+                    roles=[username],  # Use username as a role
+                    teams=groups,  # Use Kubernetes groups as teams
+                ),
             )

         except Exception as e:

@@ -162,18 +178,125 @@ class KubernetesAuthProvider(AuthProvider):
         self._client = None


+def get_attributes_from_claims(claims: dict[str, str], mapping: dict[str, str]) -> AccessAttributes:
+    attributes = AccessAttributes()
+    for claim_key, attribute_key in mapping.items():
+        if claim_key not in claims or not hasattr(attributes, attribute_key):
+            continue
+        claim = claims[claim_key]
+        if isinstance(claim, list):
+            values = claim
+        else:
+            values = claim.split()
+
+        current = getattr(attributes, attribute_key)
+        if current:
+            current.extend(values)
+        else:
+            setattr(attributes, attribute_key, values)
+    return attributes
+
+
+class OAuth2TokenAuthProviderConfig(BaseModel):
+    # The JWKS URI for collecting public keys
+    jwks_uri: str
+    cache_ttl: int = 3600
+    audience: str = "llama-stack"
+    claims_mapping: dict[str, str] = Field(
+        default_factory=lambda: {
+            "sub": "roles",
+            "username": "roles",
+            "groups": "teams",
+            "team": "teams",
+            "project": "projects",
+            "tenant": "namespaces",
+            "namespace": "namespaces",
+        },
+    )
+
+    @classmethod
+    @field_validator("claims_mapping")
+    def validate_claims_mapping(cls, v):
+        for key, value in v.items():
+            if not value:
+                raise ValueError(f"claims_mapping value cannot be empty: {key}")
+            if value not in AccessAttributes.model_fields:
+                raise ValueError(f"claims_mapping value is not a valid attribute: {value}")
+        return v
+
+
+class OAuth2TokenAuthProvider(AuthProvider):
+    """
+    JWT token authentication provider that validates a JWT token and extracts access attributes.
+
+    This should be the standard authentication provider for most use cases.
+    """
+
+    def __init__(self, config: OAuth2TokenAuthProviderConfig):
+        self.config = config
+        self._jwks_at: float = 0.0
+        self._jwks: dict[str, str] = {}
+
+    async def validate_token(self, token: str, scope: dict | None = None) -> TokenValidationResult:
+        """Validate a token using the JWT token."""
+        await self._refresh_jwks()
+
+        try:
+            header = jwt.get_unverified_header(token)
+            kid = header["kid"]
+            if kid not in self._jwks:
+                raise ValueError(f"Unknown key ID: {kid}")
+            key_data = self._jwks[kid]
+            algorithm = header.get("alg", "RS256")
+            claims = jwt.decode(
+                token,
+                key_data,
+                algorithms=[algorithm],
+                audience=self.config.audience,
+                options={"verify_exp": True},
+            )
+        except Exception as exc:
+            raise ValueError(f"Invalid JWT token: {token}") from exc
+
+        # There are other standard claims, the most relevant of which is `scope`.
+        # We should incorporate these into the access attributes.
+        principal = claims["sub"]
+        access_attributes = get_attributes_from_claims(claims, self.config.claims_mapping)
+        return TokenValidationResult(
+            principal=principal,
+            access_attributes=access_attributes,
+        )
+
+    async def close(self):
+        """Close the HTTP client."""
+
+    async def _refresh_jwks(self) -> None:
+        if time.time() - self._jwks_at > self.config.cache_ttl:
+            async with httpx.AsyncClient() as client:
+                res = await client.get(self.config.jwks_uri, timeout=5)
+                res.raise_for_status()
+                jwks_data = res.json()["keys"]
+                self._jwks = {}
+                for k in jwks_data:
+                    kid = k["kid"]
+                    # Store the entire key object as it may be needed for different algorithms
+                    self._jwks[kid] = k
+                self._jwks_at = time.time()
+
+
+class CustomAuthProviderConfig(BaseModel):
+    endpoint: str
+
+
 class CustomAuthProvider(AuthProvider):
     """Custom authentication provider that uses an external endpoint."""

-    def __init__(self, config: dict[str, str]):
-        self.endpoint = config["endpoint"]
+    def __init__(self, config: CustomAuthProviderConfig):
+        self.config = config
         self._client = None

-    async def validate_token(self, token: str, scope: dict | None = None) -> AccessAttributes | None:
+    async def validate_token(self, token: str, scope: dict | None = None) -> TokenValidationResult:
         """Validate a token using the custom authentication endpoint."""
-        if not self.endpoint:
-            raise ValueError("Authentication endpoint not configured")
-
         if scope is None:
             scope = {}

@@ -202,7 +325,7 @@ class CustomAuthProvider(AuthProvider):
         try:
             async with httpx.AsyncClient() as client:
                 response = await client.post(
-                    self.endpoint,
+                    self.config.endpoint,
                     json=auth_request.model_dump(),
                     timeout=10.0,  # Add a reasonable timeout
                 )

@@ -214,19 +337,7 @@ class CustomAuthProvider(AuthProvider):
                 try:
                     response_data = response.json()
                     auth_response = AuthResponse(**response_data)
-
-                    # Store attributes in request scope for access control
-                    if auth_response.access_attributes:
-                        return auth_response.access_attributes
-                    else:
-                        logger.warning("No access attributes, setting namespace to api_key by default")
-                        user_attributes = {
-                            "namespaces": [token],
-                        }
-
-                        scope["user_attributes"] = user_attributes
-                        logger.debug(f"Authentication successful: {len(user_attributes)} attributes")
-                        return auth_response.access_attributes
+                    return auth_response
                 except Exception as e:
                     logger.exception("Error parsing authentication response")
                     raise ValueError("Invalid authentication response format") from e

@@ -253,9 +364,11 @@ def create_auth_provider(config: AuthProviderConfig) -> AuthProvider:
     provider_type = config.provider_type.lower()

     if provider_type == "kubernetes":
-        return KubernetesAuthProvider(config.config)
+        return KubernetesAuthProvider(KubernetesAuthProviderConfig.model_validate(config.config))
     elif provider_type == "custom":
-        return CustomAuthProvider(config.config)
+        return CustomAuthProvider(CustomAuthProviderConfig.model_validate(config.config))
+    elif provider_type == "oauth2_token":
+        return OAuth2TokenAuthProvider(OAuth2TokenAuthProviderConfig.model_validate(config.config))
     else:
         supported_providers = ", ".join([t.value for t in AuthProviderType])
         raise ValueError(f"Unsupported auth provider type: {provider_type}. Supported types are: {supported_providers}")
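
To wire the new provider into a server, the auth section of a run config names oauth2_token and supplies the fields of OAuth2TokenAuthProviderConfig. The sketch below constructs the provider directly; the JWKS URI points at a placeholder identity provider, and the import path is assumed.

    from llama_stack.distribution.server.auth_providers import (  # module path assumed
        AuthProviderConfig,
        create_auth_provider,
    )

    provider = create_auth_provider(
        AuthProviderConfig(
            provider_type="oauth2_token",
            config={
                "jwks_uri": "https://issuer.example.com/.well-known/jwks.json",  # placeholder IdP
                "audience": "llama-stack",  # must match the `aud` claim in issued tokens
            },
        )
    )
    print(type(provider).__name__)  # OAuth2TokenAuthProvider
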
@@ -9,6 +9,7 @@ import asyncio
 import inspect
 import json
 import os
+import ssl
 import sys
 import traceback
 import warnings

@@ -17,6 +18,7 @@ from importlib.metadata import version as parse_version
 from pathlib import Path
 from typing import Annotated, Any

+import rich.pretty
 import yaml
 from fastapi import Body, FastAPI, HTTPException, Request
 from fastapi import Path as FastapiPath

@@ -114,7 +116,7 @@ def translate_exception(exc: Exception) -> HTTPException | RequestValidationError:
         return HTTPException(status_code=400, detail=str(exc))
     elif isinstance(exc, PermissionError):
         return HTTPException(status_code=403, detail=f"Permission denied: {str(exc)}")
-    elif isinstance(exc, TimeoutError):
+    elif isinstance(exc, asyncio.TimeoutError | TimeoutError):
         return HTTPException(status_code=504, detail=f"Operation timed out: {str(exc)}")
     elif isinstance(exc, NotImplementedError):
         return HTTPException(status_code=501, detail=f"Not implemented: {str(exc)}")

@@ -139,7 +141,7 @@ async def shutdown(app):
             await asyncio.wait_for(impl.shutdown(), timeout=5)
         else:
             logger.warning("No shutdown method for %s", impl_name)
-    except asyncio.TimeoutError:
+    except (asyncio.TimeoutError, TimeoutError):
         logger.exception("Shutdown timeout for %s ", impl_name, exc_info=True)
     except (Exception, asyncio.CancelledError) as e:
         logger.exception("Failed to shutdown %s: %s", impl_name, {e})

@@ -186,11 +188,30 @@ async def sse_generator(event_gen_coroutine):
         )


+async def log_request_pre_validation(request: Request):
+    if request.method in ("POST", "PUT", "PATCH"):
+        try:
+            body_bytes = await request.body()
+            if body_bytes:
+                try:
+                    parsed_body = json.loads(body_bytes.decode())
+                    log_output = rich.pretty.pretty_repr(parsed_body)
+                except (json.JSONDecodeError, UnicodeDecodeError):
+                    log_output = repr(body_bytes)
+                logger.debug(f"Incoming raw request body for {request.method} {request.url.path}:\n{log_output}")
+            else:
+                logger.debug(f"Incoming {request.method} {request.url.path} request with empty body.")
+        except Exception as e:
+            logger.warning(f"Could not read or log request body for {request.method} {request.url.path}: {e}")
+
+
 def create_dynamic_typed_route(func: Any, method: str, route: str):
     async def endpoint(request: Request, **kwargs):
         # Get auth attributes from the request scope
         user_attributes = request.scope.get("user_attributes", {})

+        await log_request_pre_validation(request)
+
         # Use context manager with both provider data and auth attributes
         with request_provider_data_context(request.headers, user_attributes):
             is_streaming = is_streaming_request(func.__name__, request, **kwargs)

@@ -337,22 +358,11 @@ def main(args: argparse.Namespace | None = None):
         default=int(os.getenv("LLAMA_STACK_PORT", 8321)),
         help="Port to listen on",
     )
-    parser.add_argument("--disable-ipv6", action="store_true", help="Whether to disable IPv6 support")
     parser.add_argument(
         "--env",
        action="append",
         help="Environment variables in KEY=value format. Can be specified multiple times.",
     )
-    parser.add_argument(
-        "--tls-keyfile",
-        help="Path to TLS key file for HTTPS",
-        required="--tls-certfile" in sys.argv,
-    )
-    parser.add_argument(
-        "--tls-certfile",
-        help="Path to TLS certificate file for HTTPS",
-        required="--tls-keyfile" in sys.argv,
-    )

     # Determine whether the server args are being passed by the "run" command, if this is the case
     # the args will be passed as a Namespace object to the main function, otherwise they will be

@@ -361,9 +371,9 @@ def main(args: argparse.Namespace | None = None):
     args = parser.parse_args()

     # Check for deprecated argument usage
-    if "--yaml-config" in sys.argv:
+    if "--config" in sys.argv:
         warnings.warn(
-            "The '--yaml-config' argument is deprecated and will be removed in a future version. Use '--config' instead.",
+            "The '--config' argument is deprecated and will be removed in a future version. Use '--config' instead.",
             DeprecationWarning,
             stacklevel=2,
         )

@@ -381,7 +391,7 @@ def main(args: argparse.Namespace | None = None):
             raise ValueError(f"Template {args.template} does not exist")
         log_line = f"Using template {args.template} config file: {config_file}"
     else:
-        raise ValueError("Either --yaml-config or --template must be provided")
+        raise ValueError("Either --config or --template must be provided")
|
|
||||||
logger_config = None
|
logger_config = None
|
||||||
with open(config_file) as fp:
|
with open(config_file) as fp:
|
||||||
|
@ -486,10 +496,6 @@ def main(args: argparse.Namespace | None = None):
|
||||||
port = args.port or config.server.port
|
port = args.port or config.server.port
|
||||||
|
|
||||||
ssl_config = None
|
ssl_config = None
|
||||||
if args.tls_keyfile:
|
|
||||||
keyfile = args.tls_keyfile
|
|
||||||
certfile = args.tls_certfile
|
|
||||||
else:
|
|
||||||
keyfile = config.server.tls_keyfile
|
keyfile = config.server.tls_keyfile
|
||||||
certfile = config.server.tls_certfile
|
certfile = config.server.tls_certfile
|
||||||
|
|
||||||
|
@ -498,9 +504,16 @@ def main(args: argparse.Namespace | None = None):
|
||||||
"ssl_keyfile": keyfile,
|
"ssl_keyfile": keyfile,
|
||||||
"ssl_certfile": certfile,
|
"ssl_certfile": certfile,
|
||||||
}
|
}
|
||||||
|
if config.server.tls_cafile:
|
||||||
|
ssl_config["ssl_ca_certs"] = config.server.tls_cafile
|
||||||
|
ssl_config["ssl_cert_reqs"] = ssl.CERT_REQUIRED
|
||||||
|
logger.info(
|
||||||
|
f"HTTPS enabled with certificates:\n Key: {keyfile}\n Cert: {certfile}\n CA: {config.server.tls_cafile}"
|
||||||
|
)
|
||||||
|
else:
|
||||||
logger.info(f"HTTPS enabled with certificates:\n Key: {keyfile}\n Cert: {certfile}")
|
logger.info(f"HTTPS enabled with certificates:\n Key: {keyfile}\n Cert: {certfile}")
|
||||||
|
|
||||||
listen_host = ["::", "0.0.0.0"] if not args.disable_ipv6 else "0.0.0.0"
|
listen_host = config.server.host or ["::", "0.0.0.0"]
|
||||||
logger.info(f"Listening on {listen_host}:{port}")
|
logger.info(f"Listening on {listen_host}:{port}")
|
||||||
|
|
||||||
uvicorn_config = {
|
uvicorn_config = {
|
||||||
|
|
|
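A note on the widened timeout handling above: through Python 3.10, `asyncio.TimeoutError` and the builtin `TimeoutError` are distinct classes, while from Python 3.11 onward they are the same class, so catching (or `isinstance`-checking) both keeps behavior identical across interpreter versions. A minimal illustration of the idea, with a made-up `slow()` coroutine:

import asyncio


async def slow() -> None:
    await asyncio.sleep(10)


async def main() -> None:
    try:
        # On 3.10 this raises asyncio.TimeoutError; on 3.11+ that name
        # is just an alias of the builtin TimeoutError.
        await asyncio.wait_for(slow(), timeout=0.1)
    except (asyncio.TimeoutError, TimeoutError):
        print("timed out on any supported Python version")


asyncio.run(main())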
@ -29,7 +29,7 @@ error_handler() {
 trap 'error_handler ${LINENO}' ERR

 if [ $# -lt 3 ]; then
-  echo "Usage: $0 <env_type> <env_path_or_name> <yaml_config> <port> <script_args...>"
+  echo "Usage: $0 <env_type> <env_path_or_name> <port> [--config <yaml_config>] [--env KEY=VALUE]..."
   exit 1
 fi

@ -40,23 +40,30 @@ env_path_or_name="$1"
 container_image="localhost/$env_path_or_name"
 shift

-yaml_config="$1"
-shift
-
 port="$1"
 shift

 SCRIPT_DIR=$(dirname "$(readlink -f "$0")")
 source "$SCRIPT_DIR/common.sh"

-# Initialize env_vars as an string
+# Initialize variables
+yaml_config=""
 env_vars=""
 other_args=""
-# Process environment variables from --env arguments
+
+# Process remaining arguments
 while [[ $# -gt 0 ]]; do
   case "$1" in
+    --config)
+      if [[ -n "$2" ]]; then
+        yaml_config="$2"
+        shift 2
+      else
+        echo -e "${RED}Error: $1 requires a CONFIG argument${NC}" >&2
+        exit 1
+      fi
+      ;;
     --env)
       if [[ -n "$2" ]]; then
         env_vars="$env_vars --env $2"
         shift 2

@ -71,6 +78,13 @@ while [[ $# -gt 0 ]]; do
       ;;
   esac
 done

+# Check if yaml_config is required based on env_type
+if [[ "$env_type" == "venv" || "$env_type" == "conda" ]] && [ -z "$yaml_config" ]; then
+  echo -e "${RED}Error: --config is required for venv and conda environments${NC}" >&2
+  exit 1
+fi
+
 PYTHON_BINARY="python"
 case "$env_type" in
   "venv")

@ -106,8 +120,14 @@ esac
 if [[ "$env_type" == "venv" || "$env_type" == "conda" ]]; then
     set -x

+    if [ -n "$yaml_config" ]; then
+        yaml_config_arg="--config $yaml_config"
+    else
+        yaml_config_arg=""
+    fi
+
     $PYTHON_BINARY -m llama_stack.distribution.server.server \
-    --yaml-config "$yaml_config" \
+    $yaml_config_arg \
     --port "$port" \
     $env_vars \
     $other_args

@ -149,15 +169,26 @@ elif [[ "$env_type" == "container" ]]; then
     version_tag=$(curl -s $URL | jq -r '.info.version')
   fi

-  $CONTAINER_BINARY run $CONTAINER_OPTS -it \
+  # Build the command with optional yaml config
+  cmd="$CONTAINER_BINARY run $CONTAINER_OPTS -it \
   -p $port:$port \
   $env_vars \
-  -v "$yaml_config:/app/config.yaml" \
   $mounts \
   --env LLAMA_STACK_PORT=$port \
   --entrypoint python \
   $container_image:$version_tag \
-  -m llama_stack.distribution.server.server \
-  --yaml-config /app/config.yaml \
-  $other_args
+  -m llama_stack.distribution.server.server"
+
+  # Add yaml config if provided, otherwise use default
+  if [ -n "$yaml_config" ]; then
+    cmd="$cmd -v $yaml_config:/app/run.yaml --config /app/run.yaml"
+  else
+    cmd="$cmd --config /app/run.yaml"
+  fi
+
+  # Add any other args
+  cmd="$cmd $other_args"
+
+  # Execute the command
+  eval $cmd
 fi
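With this change the launcher takes the config as a flag instead of a positional argument: an invocation that used to look like `run.sh venv my-env run.yaml 8321` becomes something like `run.sh venv my-env 8321 --config run.yaml --env INFERENCE_MODEL=my-model` (script name, env name, and variable are illustrative). For the container path, `--config` may be omitted entirely, in which case the server falls back to the `/app/run.yaml` baked into the image.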
@ -73,7 +73,7 @@ class DiskDistributionRegistry(DistributionRegistry):

     async def get_all(self) -> list[RoutableObjectWithProvider]:
         start_key, end_key = _get_registry_key_range()
-        values = await self.kvstore.range(start_key, end_key)
+        values = await self.kvstore.values_in_range(start_key, end_key)
         return _parse_registry_values(values)

     async def get(self, type: str, identifier: str) -> RoutableObjectWithProvider | None:

@ -134,7 +134,7 @@ class CachedDiskDistributionRegistry(DiskDistributionRegistry):
             return

         start_key, end_key = _get_registry_key_range()
-        values = await self.kvstore.range(start_key, end_key)
+        values = await self.kvstore.values_in_range(start_key, end_key)
         objects = _parse_registry_values(values)

         async with self._locked_cache() as cache:
|
||||||
message_placeholder.markdown(full_response + "▌")
|
message_placeholder.markdown(full_response + "▌")
|
||||||
message_placeholder.markdown(full_response)
|
message_placeholder.markdown(full_response)
|
||||||
else:
|
else:
|
||||||
full_response = response
|
full_response = response.completion_message.content
|
||||||
message_placeholder.markdown(full_response.completion_message.content)
|
message_placeholder.markdown(full_response)
|
||||||
|
|
||||||
st.session_state.messages.append({"role": "assistant", "content": full_response})
|
st.session_state.messages.append({"role": "assistant", "content": full_response})
|
||||||
|
|
|
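The chat playground hunk fixes an inverted assignment: previously the raw response object was appended to `st.session_state.messages` while only its text was rendered, so replaying the history would surface the object rather than the message. Extracting `response.completion_message.content` first keeps the stored history a plain string.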
@ -14,3 +14,5 @@ DISTRIBS_BASE_DIR = LLAMA_STACK_CONFIG_DIR / "distributions"
 DEFAULT_CHECKPOINT_DIR = LLAMA_STACK_CONFIG_DIR / "checkpoints"

 RUNTIME_BASE_DIR = LLAMA_STACK_CONFIG_DIR / "runtime"
+
+EXTERNAL_PROVIDERS_DIR = LLAMA_STACK_CONFIG_DIR / "providers.d"
@ -22,8 +22,10 @@ from llama_stack.distribution.utils.image_types import LlamaStackImageType

 def formulate_run_args(image_type, image_name, config, template_name) -> list:
     env_name = ""
-    if image_type == LlamaStackImageType.CONTAINER.value or config.container_image:
-        env_name = f"distribution-{template_name}" if template_name else config.container_image
+    if image_type == LlamaStackImageType.CONTAINER.value:
+        env_name = (
+            f"distribution-{template_name}" if template_name else (config.container_image if config else image_name)
+        )
     elif image_type == LlamaStackImageType.CONDA.value:
         current_conda_env = os.environ.get("CONDA_DEFAULT_ENV")
         env_name = image_name or current_conda_env
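The container branch now packs three fallbacks into one conditional expression; unrolled for readability, the precedence it encodes is roughly the following (a sketch, not a drop-in replacement):

def container_env_name(template_name, config, image_name):
    # Same precedence as the expression above: template name first,
    # then the config's container image, then the raw image name.
    if template_name:
        return f"distribution-{template_name}"
    if config:
        return config.container_image
    return image_name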
@ -245,7 +245,7 @@ class PythonListCustomToolGenerator(PromptTemplateGeneratorBase):  # noqa: N801
             {"function_description": self._gen_function_description(custom_tools)},
         )

-    def _gen_function_description(self, custom_tools: list[ToolDefinition]) -> PromptTemplate:
+    def _gen_function_description(self, custom_tools: list[ToolDefinition]) -> str:
         template_str = textwrap.dedent(
             """
             Here is a list of functions in JSON format that you can invoke.

@ -286,10 +286,12 @@ class PythonListCustomToolGenerator(PromptTemplateGeneratorBase):  # noqa: N801

             """
         )
-        return PromptTemplate(
+        template = PromptTemplate(
             template_str.strip("\n"),
             {"tools": [t.model_dump() for t in custom_tools]},
-        ).render()
+        )
+        rendered: str = template.render()
+        return rendered

     def data_examples(self) -> list[list[ToolDefinition]]:
         return [
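Splitting the chained `PromptTemplate(...).render()` into a named `template` plus an explicitly annotated `rendered: str` changes no behavior; it lines the body up with the corrected `-> str` return annotation above and gives type checkers an explicit anchor.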
@ -173,9 +173,7 @@ INCORRECT: [get_events(location="Singapore")] <- If function not in list
 - Don't repeat tool response verbatim
 - Don't add supplementary information

-Here is a list of functions in JSON format that you can invoke.
+Here is a list of functions in JSON format that you can invoke:

 [
     {
         "name": "get_weather",

@ -196,10 +194,7 @@ Here is a list of functions in JSON format that you can invoke.
         }
     }
 }
-]
+]<|eot|><|header_start|>user<|header_end|>

-You can answer general questions or invoke tools when necessary.
-In addition to tool calls, you should also augment your responses by using the tool outputs.<|eot|><|header_start|>user<|header_end|>
-
 What is the weather in SF and Seattle?<|eot|><|header_start|>assistant<|header_end|>
@ -61,7 +61,6 @@ class PythonListCustomToolGenerator(PromptTemplateGeneratorBase):  # noqa: N801
             - Don't repeat tool response verbatim
             - Don't add supplementary information

-
             {{ function_description }}
             """.strip("\n")
         )

@ -76,8 +75,7 @@ class PythonListCustomToolGenerator(PromptTemplateGeneratorBase):  # noqa: N801
     def _gen_function_description(self, custom_tools: list[ToolDefinition]) -> PromptTemplate:
         template_str = textwrap.dedent(
             """
-            Here is a list of functions in JSON format that you can invoke.
+            Here is a list of functions in JSON format that you can invoke:

             [
             {% for t in tools -%}
             {# manually setting up JSON because jinja sorts keys in unexpected ways -#}

@ -108,10 +106,6 @@ class PythonListCustomToolGenerator(PromptTemplateGeneratorBase):  # noqa: N801
             {% endif -%}
             {%- endfor %}
             ]
-
-            You can answer general questions or invoke tools when necessary.
-            In addition to tool calls, you should also augment your responses by using the tool outputs.

             """
         )
         return PromptTemplate(
@ -948,6 +948,8 @@ def llama_meta_net_info(model: Model) -> LlamaDownloadInfo:
     elif model.core_model_id == CoreModelId.llama_guard_2_8b:
         folder = "llama-guard-2"
     else:
+        if model.huggingface_repo is None:
+            raise ValueError(f"Model {model.core_model_id} has no huggingface_repo set")
         folder = model.huggingface_repo.split("/")[-1]
         if "Llama-2" in folder:
             folder = folder.lower()

@ -1024,3 +1026,4 @@ def llama_meta_pth_size(model: Model) -> int:
             return 54121549657
         else:
             return 100426653046
+    return 0
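Both hunks in this file make the helpers total for every input: the `huggingface_repo` guard turns what would be an `AttributeError` on `None.split` into an explicit `ValueError`, and the trailing `return 0` gives `llama_meta_pth_size` a defined value on the remaining fall-through path.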
@ -95,6 +95,7 @@ class ChatAgent(ShieldRunnerMixin):
         tool_groups_api: ToolGroups,
         vector_io_api: VectorIO,
         persistence_store: KVStore,
+        created_at: str,
     ):
         self.agent_id = agent_id
         self.agent_config = agent_config

@ -104,6 +105,7 @@ class ChatAgent(ShieldRunnerMixin):
         self.storage = AgentPersistence(agent_id, persistence_store)
         self.tool_runtime_api = tool_runtime_api
         self.tool_groups_api = tool_groups_api
+        self.created_at = created_at

         ShieldRunnerMixin.__init__(
             self,
@ -4,10 +4,10 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-import json
 import logging
 import uuid
 from collections.abc import AsyncGenerator
+from datetime import datetime, timezone

 from llama_stack.apis.agents import (
     Agent,

@ -20,14 +20,13 @@ from llama_stack.apis.agents import (
     AgentTurnCreateRequest,
     AgentTurnResumeRequest,
     Document,
-    ListAgentSessionsResponse,
-    ListAgentsResponse,
-    OpenAIResponseInputMessage,
+    OpenAIResponseInput,
     OpenAIResponseInputTool,
     OpenAIResponseObject,
     Session,
     Turn,
 )
+from llama_stack.apis.common.responses import PaginatedResponse
 from llama_stack.apis.inference import (
     Inference,
     ToolConfig,

@ -39,13 +38,14 @@ from llama_stack.apis.safety import Safety
 from llama_stack.apis.tools import ToolGroups, ToolRuntime
 from llama_stack.apis.vector_io import VectorIO
 from llama_stack.providers.utils.kvstore import InmemoryKVStoreImpl, kvstore_impl
+from llama_stack.providers.utils.pagination import paginate_records

 from .agent_instance import ChatAgent
 from .config import MetaReferenceAgentsImplConfig
 from .openai_responses import OpenAIResponsesImpl
+from .persistence import AgentInfo

 logger = logging.getLogger()
-logger.setLevel(logging.INFO)


 class MetaReferenceAgentsImpl(Agents):

@ -82,43 +82,47 @@ class MetaReferenceAgentsImpl(Agents):
         agent_config: AgentConfig,
     ) -> AgentCreateResponse:
         agent_id = str(uuid.uuid4())
+        created_at = datetime.now(timezone.utc)
+
+        agent_info = AgentInfo(
+            **agent_config.model_dump(),
+            created_at=created_at,
+        )
+
+        # Store the agent info
         await self.persistence_store.set(
             key=f"agent:{agent_id}",
-            value=agent_config.model_dump_json(),
+            value=agent_info.model_dump_json(),
         )

         return AgentCreateResponse(
             agent_id=agent_id,
         )

     async def _get_agent_impl(self, agent_id: str) -> ChatAgent:
-        agent_config = await self.persistence_store.get(
+        agent_info_json = await self.persistence_store.get(
             key=f"agent:{agent_id}",
         )
-        if not agent_config:
-            raise ValueError(f"Could not find agent config for {agent_id}")
+        if not agent_info_json:
+            raise ValueError(f"Could not find agent info for {agent_id}")

         try:
-            agent_config = json.loads(agent_config)
-        except json.JSONDecodeError as e:
-            raise ValueError(f"Could not JSON decode agent config for {agent_id}") from e
-
-        try:
-            agent_config = AgentConfig(**agent_config)
+            agent_info = AgentInfo.model_validate_json(agent_info_json)
         except Exception as e:
-            raise ValueError(f"Could not validate(?) agent config for {agent_id}") from e
+            raise ValueError(f"Could not validate agent info for {agent_id}") from e

         return ChatAgent(
             agent_id=agent_id,
-            agent_config=agent_config,
+            agent_config=agent_info,
             inference_api=self.inference_api,
             safety_api=self.safety_api,
             vector_io_api=self.vector_io_api,
             tool_runtime_api=self.tool_runtime_api,
             tool_groups_api=self.tool_groups_api,
             persistence_store=(
-                self.persistence_store if agent_config.enable_session_persistence else self.in_memory_store
+                self.persistence_store if agent_info.enable_session_persistence else self.in_memory_store
             ),
+            created_at=agent_info.created_at,
         )

     async def create_agent_session(

@ -212,6 +216,7 @@ class MetaReferenceAgentsImpl(Agents):
         turn_ids: list[str] | None = None,
     ) -> Session:
         agent = await self._get_agent_impl(agent_id)
+
         session_info = await agent.storage.get_session_info(session_id)
         if session_info is None:
             raise ValueError(f"Session {session_id} not found")

@ -226,24 +231,75 @@ class MetaReferenceAgentsImpl(Agents):
         )

     async def delete_agents_session(self, agent_id: str, session_id: str) -> None:
-        await self.persistence_store.delete(f"session:{agent_id}:{session_id}")
+        agent = await self._get_agent_impl(agent_id)
+        session_info = await agent.storage.get_session_info(session_id)
+        if session_info is None:
+            raise ValueError(f"Session {session_id} not found")
+
+        # Delete turns first, then the session
+        await agent.storage.delete_session_turns(session_id)
+        await agent.storage.delete_session(session_id)

     async def delete_agent(self, agent_id: str) -> None:
+        # First get all sessions for this agent
+        agent = await self._get_agent_impl(agent_id)
+        sessions = await agent.storage.list_sessions()
+
+        # Delete all sessions
+        for session in sessions:
+            await self.delete_agents_session(agent_id, session.session_id)
+
+        # Finally delete the agent itself
         await self.persistence_store.delete(f"agent:{agent_id}")

-    async def shutdown(self) -> None:
-        pass
-
-    async def list_agents(self) -> ListAgentsResponse:
-        pass
+    async def list_agents(self, start_index: int | None = None, limit: int | None = None) -> PaginatedResponse:
+        agent_keys = await self.persistence_store.keys_in_range("agent:", "agent:\xff")
+        agent_list: list[Agent] = []
+        for agent_key in agent_keys:
+            agent_id = agent_key.split(":")[1]
+
+            # Get the agent info using the key
+            agent_info_json = await self.persistence_store.get(agent_key)
+            if not agent_info_json:
+                logger.error(f"Could not find agent info for key {agent_key}")
+                continue
+
+            try:
+                agent_info = AgentInfo.model_validate_json(agent_info_json)
+                agent_list.append(
+                    Agent(
+                        agent_id=agent_id,
+                        agent_config=agent_info,
+                        created_at=agent_info.created_at,
+                    )
+                )
+            except Exception as e:
+                logger.error(f"Error parsing agent info for {agent_id}: {e}")
+                continue
+
+        # Convert Agent objects to dictionaries
+        agent_dicts = [agent.model_dump() for agent in agent_list]
+        return paginate_records(agent_dicts, start_index, limit)

     async def get_agent(self, agent_id: str) -> Agent:
-        pass
+        chat_agent = await self._get_agent_impl(agent_id)
+        agent = Agent(
+            agent_id=agent_id,
+            agent_config=chat_agent.agent_config,
+            created_at=chat_agent.created_at,
+        )
+        return agent

     async def list_agent_sessions(
-        self,
-        agent_id: str,
-    ) -> ListAgentSessionsResponse:
+        self, agent_id: str, start_index: int | None = None, limit: int | None = None
+    ) -> PaginatedResponse:
+        agent = await self._get_agent_impl(agent_id)
+        sessions = await agent.storage.list_sessions()
+        # Convert Session objects to dictionaries
+        session_dicts = [session.model_dump() for session in sessions]
+        return paginate_records(session_dicts, start_index, limit)
+
+    async def shutdown(self) -> None:
         pass

     # OpenAI responses

@ -255,7 +311,7 @@ class MetaReferenceAgentsImpl(Agents):

     async def create_openai_response(
         self,
-        input: str | list[OpenAIResponseInputMessage],
+        input: str | list[OpenAIResponseInput],
         model: str,
         previous_response_id: str | None = None,
         store: bool | None = True,
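`list_agents` and `list_agent_sessions` now funnel their results through the shared `paginate_records` helper with `start_index`/`limit`. A toy rendition of that contract — only the call shape is taken from the diff; the response fields here are assumptions for illustration:

from dataclasses import dataclass
from typing import Any


@dataclass
class Page:
    data: list[dict[str, Any]]
    has_more: bool


def paginate_records_sketch(
    records: list[dict[str, Any]], start_index: int | None, limit: int | None
) -> Page:
    # Assumed semantics: slice the full record list and flag whether anything remains.
    start = start_index or 0
    end = len(records) if limit is None else start + limit
    return Page(data=records[start:end], has_more=end < len(records))


# paginate_records_sketch([{"agent_id": "a"}, {"agent_id": "b"}], 0, 1)
# -> Page(data=[{'agent_id': 'a'}], has_more=True)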
@ -7,22 +7,29 @@
 import json
 import uuid
 from collections.abc import AsyncIterator
-from typing import cast
+from typing import Any, cast

 from openai.types.chat import ChatCompletionToolParam
+from pydantic import BaseModel

 from llama_stack.apis.agents.openai_responses import (
-    OpenAIResponseInputMessage,
+    OpenAIResponseInput,
+    OpenAIResponseInputFunctionToolCallOutput,
+    OpenAIResponseInputItemList,
+    OpenAIResponseInputMessageContent,
     OpenAIResponseInputMessageContentImage,
     OpenAIResponseInputMessageContentText,
     OpenAIResponseInputTool,
+    OpenAIResponseInputToolFunction,
+    OpenAIResponseMessage,
     OpenAIResponseObject,
     OpenAIResponseObjectStream,
     OpenAIResponseObjectStreamResponseCompleted,
     OpenAIResponseObjectStreamResponseCreated,
     OpenAIResponseOutput,
-    OpenAIResponseOutputMessage,
+    OpenAIResponseOutputMessageContent,
     OpenAIResponseOutputMessageContentOutputText,
+    OpenAIResponseOutputMessageFunctionToolCall,
     OpenAIResponseOutputMessageWebSearchToolCall,
 )
 from llama_stack.apis.inference.inference import (

@ -32,10 +39,13 @@ from llama_stack.apis.inference.inference import (
     OpenAIChatCompletionContentPartImageParam,
     OpenAIChatCompletionContentPartParam,
     OpenAIChatCompletionContentPartTextParam,
+    OpenAIChatCompletionToolCall,
     OpenAIChatCompletionToolCallFunction,
     OpenAIChoice,
+    OpenAIDeveloperMessageParam,
     OpenAIImageURL,
     OpenAIMessageParam,
+    OpenAISystemMessageParam,
     OpenAIToolMessageParam,
     OpenAIUserMessageParam,
 )

@ -50,31 +60,110 @@ logger = get_logger(name=__name__, category="openai_responses")
 OPENAI_RESPONSES_PREFIX = "openai_responses:"


-async def _previous_response_to_messages(previous_response: OpenAIResponseObject) -> list[OpenAIMessageParam]:
+async def _convert_response_content_to_chat_content(
+    content: str | list[OpenAIResponseInputMessageContent] | list[OpenAIResponseOutputMessageContent],
+) -> str | list[OpenAIChatCompletionContentPartParam]:
+    """
+    Convert the content parts from an OpenAI Response API request into OpenAI Chat Completion content parts.
+
+    The content schemas of each API look similar, but are not exactly the same.
+    """
+    if isinstance(content, str):
+        return content
+
+    converted_parts = []
+    for content_part in content:
+        if isinstance(content_part, OpenAIResponseInputMessageContentText):
+            converted_parts.append(OpenAIChatCompletionContentPartTextParam(text=content_part.text))
+        elif isinstance(content_part, OpenAIResponseOutputMessageContentOutputText):
+            converted_parts.append(OpenAIChatCompletionContentPartTextParam(text=content_part.text))
+        elif isinstance(content_part, OpenAIResponseInputMessageContentImage):
+            if content_part.image_url:
+                image_url = OpenAIImageURL(url=content_part.image_url, detail=content_part.detail)
+                converted_parts.append(OpenAIChatCompletionContentPartImageParam(image_url=image_url))
+        elif isinstance(content_part, str):
+            converted_parts.append(OpenAIChatCompletionContentPartTextParam(text=content_part))
+        else:
+            raise ValueError(
+                f"Llama Stack OpenAI Responses does not yet support content type '{type(content_part)}' in this context"
+            )
+    return converted_parts
+
+
+async def _convert_response_input_to_chat_messages(
+    input: str | list[OpenAIResponseInput],
+) -> list[OpenAIMessageParam]:
+    """
+    Convert the input from an OpenAI Response API request into OpenAI Chat Completion messages.
+    """
     messages: list[OpenAIMessageParam] = []
-    for output_message in previous_response.output:
-        if isinstance(output_message, OpenAIResponseOutputMessage):
-            messages.append(OpenAIAssistantMessageParam(content=output_message.content[0].text))
+    if isinstance(input, list):
+        for input_item in input:
+            if isinstance(input_item, OpenAIResponseInputFunctionToolCallOutput):
+                messages.append(
+                    OpenAIToolMessageParam(
+                        content=input_item.output,
+                        tool_call_id=input_item.call_id,
+                    )
+                )
+            elif isinstance(input_item, OpenAIResponseOutputMessageFunctionToolCall):
+                tool_call = OpenAIChatCompletionToolCall(
+                    index=0,
+                    id=input_item.call_id,
+                    function=OpenAIChatCompletionToolCallFunction(
+                        name=input_item.name,
+                        arguments=input_item.arguments,
+                    ),
+                )
+                messages.append(OpenAIAssistantMessageParam(tool_calls=[tool_call]))
+            else:
+                content = await _convert_response_content_to_chat_content(input_item.content)
+                message_type = await _get_message_type_by_role(input_item.role)
+                if message_type is None:
+                    raise ValueError(
+                        f"Llama Stack OpenAI Responses does not yet support message role '{input_item.role}' in this context"
+                    )
+                messages.append(message_type(content=content))
+    else:
+        messages.append(OpenAIUserMessageParam(content=input))
     return messages


-async def _openai_choices_to_output_messages(choices: list[OpenAIChoice]) -> list[OpenAIResponseOutputMessage]:
-    output_messages = []
-    for choice in choices:
-        output_content = ""
-        if isinstance(choice.message.content, str):
-            output_content = choice.message.content
-        elif isinstance(choice.message.content, OpenAIChatCompletionContentPartTextParam):
-            output_content = choice.message.content.text
-        # TODO: handle image content
-        output_messages.append(
-            OpenAIResponseOutputMessage(
-                id=f"msg_{uuid.uuid4()}",
-                content=[OpenAIResponseOutputMessageContentOutputText(text=output_content)],
-                status="completed",
-            )
-        )
-    return output_messages
+async def _convert_chat_choice_to_response_message(choice: OpenAIChoice) -> OpenAIResponseMessage:
+    """
+    Convert an OpenAI Chat Completion choice into an OpenAI Response output message.
+    """
+    output_content = ""
+    if isinstance(choice.message.content, str):
+        output_content = choice.message.content
+    elif isinstance(choice.message.content, OpenAIChatCompletionContentPartTextParam):
+        output_content = choice.message.content.text
+    else:
+        raise ValueError(
+            f"Llama Stack OpenAI Responses does not yet support output content type: {type(choice.message.content)}"
+        )
+
+    return OpenAIResponseMessage(
+        id=f"msg_{uuid.uuid4()}",
+        content=[OpenAIResponseOutputMessageContentOutputText(text=output_content)],
+        status="completed",
+        role="assistant",
+    )
+
+
+async def _get_message_type_by_role(role: str):
+    role_to_type = {
+        "user": OpenAIUserMessageParam,
+        "system": OpenAISystemMessageParam,
+        "assistant": OpenAIAssistantMessageParam,
+        "developer": OpenAIDeveloperMessageParam,
+    }
+    return role_to_type.get(role)
+
+
+class OpenAIResponsePreviousResponseWithInputItems(BaseModel):
+    input_items: OpenAIResponseInputItemList
+    response: OpenAIResponseObject


 class OpenAIResponsesImpl:

@ -90,19 +179,45 @@ class OpenAIResponsesImpl:
         self.tool_groups_api = tool_groups_api
         self.tool_runtime_api = tool_runtime_api

-    async def get_openai_response(
-        self,
-        id: str,
-    ) -> OpenAIResponseObject:
+    async def _get_previous_response_with_input(self, id: str) -> OpenAIResponsePreviousResponseWithInputItems:
         key = f"{OPENAI_RESPONSES_PREFIX}{id}"
         response_json = await self.persistence_store.get(key=key)
         if response_json is None:
             raise ValueError(f"OpenAI response with id '{id}' not found")
-        return OpenAIResponseObject.model_validate_json(response_json)
+        return OpenAIResponsePreviousResponseWithInputItems.model_validate_json(response_json)
+
+    async def _prepend_previous_response(
+        self, input: str | list[OpenAIResponseInput], previous_response_id: str | None = None
+    ):
+        if previous_response_id:
+            previous_response_with_input = await self._get_previous_response_with_input(previous_response_id)
+
+            # previous response input items
+            new_input_items = previous_response_with_input.input_items.data
+
+            # previous response output items
+            new_input_items.extend(previous_response_with_input.response.output)
+
+            # new input items from the current request
+            if isinstance(input, str):
+                new_input_items.append(OpenAIResponseMessage(content=input, role="user"))
+            else:
+                new_input_items.extend(input)
+
+            input = new_input_items
+
+        return input
+
+    async def get_openai_response(
+        self,
+        id: str,
+    ) -> OpenAIResponseObject:
+        response_with_input = await self._get_previous_response_with_input(id)
+        return response_with_input.response

     async def create_openai_response(
         self,
-        input: str | list[OpenAIResponseInputMessage],
+        input: str | list[OpenAIResponseInput],
         model: str,
         previous_response_id: str | None = None,
         store: bool | None = True,

@ -112,31 +227,8 @@ class OpenAIResponsesImpl:
     ):
         stream = False if stream is None else stream

-        messages: list[OpenAIMessageParam] = []
-        if previous_response_id:
-            previous_response = await self.get_openai_response(previous_response_id)
-            messages.extend(await _previous_response_to_messages(previous_response))
-        # TODO: refactor this user_content parsing out into a separate method
-        user_content: str | list[OpenAIChatCompletionContentPartParam] = ""
-        if isinstance(input, list):
-            user_content = []
-            for user_input in input:
-                if isinstance(user_input.content, list):
-                    for user_input_content in user_input.content:
-                        if isinstance(user_input_content, OpenAIResponseInputMessageContentText):
-                            user_content.append(OpenAIChatCompletionContentPartTextParam(text=user_input_content.text))
-                        elif isinstance(user_input_content, OpenAIResponseInputMessageContentImage):
-                            if user_input_content.image_url:
-                                image_url = OpenAIImageURL(
-                                    url=user_input_content.image_url, detail=user_input_content.detail
-                                )
-                                user_content.append(OpenAIChatCompletionContentPartImageParam(image_url=image_url))
-                else:
-                    user_content.append(OpenAIChatCompletionContentPartTextParam(text=user_input.content))
-        else:
-            user_content = input
-        messages.append(OpenAIUserMessageParam(content=user_content))
+        input = await self._prepend_previous_response(input, previous_response_id)
+        messages = await _convert_response_input_to_chat_messages(input)

         chat_tools = await self._convert_response_tools_to_chat_tools(tools) if tools else None
         chat_response = await self.inference_api.openai_chat_completion(
             model=model,

@ -150,6 +242,7 @@ class OpenAIResponsesImpl:
             # TODO: refactor this into a separate method that handles streaming
             chat_response_id = ""
             chat_response_content = []
+            chat_response_tool_calls: dict[int, OpenAIChatCompletionToolCall] = {}
             # TODO: these chunk_ fields are hacky and only take the last chunk into account
             chunk_created = 0
             chunk_model = ""

@ -163,7 +256,30 @@ class OpenAIResponsesImpl:
                         chat_response_content.append(chunk_choice.delta.content or "")
                         if chunk_choice.finish_reason:
                             chunk_finish_reason = chunk_choice.finish_reason
-            assistant_message = OpenAIAssistantMessageParam(content="".join(chat_response_content))
+
+                        # Aggregate tool call arguments across chunks, using their index as the aggregation key
+                        if chunk_choice.delta.tool_calls:
+                            for tool_call in chunk_choice.delta.tool_calls:
+                                response_tool_call = chat_response_tool_calls.get(tool_call.index, None)
+                                if response_tool_call:
+                                    response_tool_call.function.arguments += tool_call.function.arguments
+                                else:
+                                    tool_call_dict: dict[str, Any] = tool_call.model_dump()
+                                    # Ensure we don't have any empty type field in the tool call dict.
+                                    # The OpenAI client used by providers often returns a type=None here.
+                                    tool_call_dict.pop("type", None)
+                                    response_tool_call = OpenAIChatCompletionToolCall(**tool_call_dict)
+                                chat_response_tool_calls[tool_call.index] = response_tool_call
+
+            # Convert the dict of tool calls by index to a list of tool calls to pass back in our response
+            if chat_response_tool_calls:
+                tool_calls = [chat_response_tool_calls[i] for i in sorted(chat_response_tool_calls.keys())]
+            else:
+                tool_calls = None
+            assistant_message = OpenAIAssistantMessageParam(
+                content="".join(chat_response_content),
+                tool_calls=tool_calls,
+            )
             chat_response = OpenAIChatCompletion(
                 id=chat_response_id,
                 choices=[

@ -181,12 +297,26 @@ class OpenAIResponsesImpl:
             chat_response = OpenAIChatCompletion(**chat_response.model_dump())

         output_messages: list[OpenAIResponseOutput] = []
-        if chat_response.choices[0].message.tool_calls:
-            output_messages.extend(
-                await self._execute_tool_and_return_final_output(model, stream, chat_response, messages, temperature)
-            )
-        else:
-            output_messages.extend(await _openai_choices_to_output_messages(chat_response.choices))
+        for choice in chat_response.choices:
+            if choice.message.tool_calls and tools:
+                # Assume if the first tool is a function, all tools are functions
+                if isinstance(tools[0], OpenAIResponseInputToolFunction):
+                    for tool_call in choice.message.tool_calls:
+                        output_messages.append(
+                            OpenAIResponseOutputMessageFunctionToolCall(
+                                arguments=tool_call.function.arguments or "",
+                                call_id=tool_call.id,
+                                name=tool_call.function.name or "",
+                                id=f"fc_{uuid.uuid4()}",
+                                status="completed",
+                            )
+                        )
+                else:
+                    output_messages.extend(
+                        await self._execute_tool_and_return_final_output(model, stream, choice, messages, temperature)
+                    )
+            else:
+                output_messages.append(await _convert_chat_choice_to_response_message(choice))
         response = OpenAIResponseObject(
             created_at=chat_response.created,
             id=f"resp-{uuid.uuid4()}",

@ -195,13 +325,43 @@ class OpenAIResponsesImpl:
             status="completed",
             output=output_messages,
         )
+        logger.debug(f"OpenAI Responses response: {response}")

         if store:
             # Store in kvstore
+
+            new_input_id = f"msg_{uuid.uuid4()}"
+            if isinstance(input, str):
+                # synthesize a message from the input string
+                input_content = OpenAIResponseInputMessageContentText(text=input)
+                input_content_item = OpenAIResponseMessage(
+                    role="user",
+                    content=[input_content],
+                    id=new_input_id,
+                )
+                input_items_data = [input_content_item]
+            else:
+                # we already have a list of messages
+                input_items_data = []
+                for input_item in input:
+                    if isinstance(input_item, OpenAIResponseMessage):
+                        # These may or may not already have an id, so dump to dict, check for id, and add if missing
+                        input_item_dict = input_item.model_dump()
+                        if "id" not in input_item_dict:
+                            input_item_dict["id"] = new_input_id
+                        input_items_data.append(OpenAIResponseMessage(**input_item_dict))
+                    else:
+                        input_items_data.append(input_item)
+
+            input_items = OpenAIResponseInputItemList(data=input_items_data)
+            prev_response = OpenAIResponsePreviousResponseWithInputItems(
+                input_items=input_items,
+                response=response,
+            )
             key = f"{OPENAI_RESPONSES_PREFIX}{response.id}"
             await self.persistence_store.set(
                 key=key,
-                value=response.model_dump_json(),
+                value=prev_response.model_dump_json(),
             )

         if stream:

@ -221,7 +381,9 @@ class OpenAIResponsesImpl:
         chat_tools: list[ChatCompletionToolParam] = []
         for input_tool in tools:
             # TODO: Handle other tool types
-            if input_tool.type == "web_search":
+            if input_tool.type == "function":
+                chat_tools.append(ChatCompletionToolParam(type="function", function=input_tool.model_dump()))
+            elif input_tool.type == "web_search":
                 tool_name = "web_search"
                 tool = await self.tool_groups_api.get_tool(tool_name)
                 tool_def = ToolDefinition(

@ -247,12 +409,11 @@ class OpenAIResponsesImpl:
         self,
         model_id: str,
         stream: bool,
-        chat_response: OpenAIChatCompletion,
+        choice: OpenAIChoice,
         messages: list[OpenAIMessageParam],
         temperature: float,
     ) -> list[OpenAIResponseOutput]:
         output_messages: list[OpenAIResponseOutput] = []
-        choice = chat_response.choices[0]

         # If the choice is not an assistant message, we don't need to execute any tools
         if not isinstance(choice.message, OpenAIAssistantMessageParam):

@ -262,6 +423,9 @@ class OpenAIResponsesImpl:
         if not choice.message.tool_calls:
             return output_messages

+        # Copy the messages list to avoid mutating the original list
+        messages = messages.copy()
+
         # Add the assistant message with tool_calls response to the messages list
         messages.append(choice.message)

@ -307,7 +471,9 @@ class OpenAIResponsesImpl:
         )
         # type cast to appease mypy
         tool_results_chat_response = cast(OpenAIChatCompletion, tool_results_chat_response)
-        tool_final_outputs = await _openai_choices_to_output_messages(tool_results_chat_response.choices)
+        tool_final_outputs = [
+            await _convert_chat_choice_to_response_message(choice) for choice in tool_results_chat_response.choices
+        ]
         # TODO: Wire in annotations with URLs, titles, etc to these output messages
         output_messages.extend(tool_final_outputs)
         return output_messages
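The streaming branch above reassembles function-call arguments that arrive sliced across chunks, keyed by `tool_call.index`. Stripped of the provider types, the accumulation pattern looks like this (the chunk payloads are invented for the demo):

# Each streamed delta carries (index, call id or None, name or None, argument fragment).
deltas = [
    (0, "call_1", "get_weather", '{"city": '),
    (0, None, None, '"Seattle"}'),
    (1, "call_2", "get_weather", '{"city": "SF"}'),
]

calls: dict[int, dict] = {}
for index, call_id, name, fragment in deltas:
    if index in calls:
        # Later chunks for the same index only extend the argument string.
        calls[index]["arguments"] += fragment
    else:
        calls[index] = {"id": call_id, "name": name, "arguments": fragment}

# Emit in index order, mirroring sorted(chat_response_tool_calls.keys()).
for index in sorted(calls):
    print(calls[index])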
@ -9,9 +9,7 @@ import logging
|
||||||
import uuid
|
import uuid
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from llama_stack.apis.agents import AgentConfig, Session, ToolExecutionStep, Turn
|
||||||
|
|
||||||
from llama_stack.apis.agents import ToolExecutionStep, Turn
|
|
||||||
from llama_stack.distribution.access_control import check_access
|
from llama_stack.distribution.access_control import check_access
|
||||||
from llama_stack.distribution.datatypes import AccessAttributes
|
from llama_stack.distribution.datatypes import AccessAttributes
|
||||||
from llama_stack.distribution.request_headers import get_auth_attributes
|
from llama_stack.distribution.request_headers import get_auth_attributes
|
||||||
|
@ -20,15 +18,17 @@ from llama_stack.providers.utils.kvstore import KVStore
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class AgentSessionInfo(BaseModel):
|
class AgentSessionInfo(Session):
|
||||||
session_id: str
|
|
||||||
session_name: str
|
|
||||||
# TODO: is this used anywhere?
|
# TODO: is this used anywhere?
|
||||||
vector_db_id: str | None = None
|
vector_db_id: str | None = None
|
||||||
started_at: datetime
|
started_at: datetime
|
||||||
access_attributes: AccessAttributes | None = None
|
access_attributes: AccessAttributes | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class AgentInfo(AgentConfig):
|
||||||
|
created_at: datetime
|
||||||
|
|
||||||
|
|
||||||
class AgentPersistence:
|
class AgentPersistence:
|
||||||
def __init__(self, agent_id: str, kvstore: KVStore):
|
def __init__(self, agent_id: str, kvstore: KVStore):
|
||||||
self.agent_id = agent_id
|
self.agent_id = agent_id
|
||||||
|
@ -46,6 +46,7 @@ class AgentPersistence:
|
||||||
session_name=name,
|
session_name=name,
|
||||||
started_at=datetime.now(timezone.utc),
|
started_at=datetime.now(timezone.utc),
|
||||||
access_attributes=access_attributes,
|
access_attributes=access_attributes,
|
||||||
|
turns=[],
|
||||||
)
|
)
|
||||||
|
|
||||||
await self.kvstore.set(
|
await self.kvstore.set(
|
||||||
|
@ -109,7 +110,7 @@ class AgentPersistence:
|
||||||
if not await self.get_session_if_accessible(session_id):
|
if not await self.get_session_if_accessible(session_id):
|
||||||
raise ValueError(f"Session {session_id} not found or access denied")
|
raise ValueError(f"Session {session_id} not found or access denied")
|
||||||
|
|
||||||
values = await self.kvstore.range(
|
values = await self.kvstore.values_in_range(
|
||||||
start_key=f"session:{self.agent_id}:{session_id}:",
|
start_key=f"session:{self.agent_id}:{session_id}:",
|
||||||
end_key=f"session:{self.agent_id}:{session_id}:\xff\xff\xff\xff",
|
end_key=f"session:{self.agent_id}:{session_id}:\xff\xff\xff\xff",
|
||||||
)
|
)
|
||||||
|
@@ -121,7 +122,6 @@ class AgentPersistence:
             except Exception as e:
                 log.error(f"Error parsing turn: {e}")
                 continue
-        turns.sort(key=lambda x: (x.completed_at or datetime.min))
         return turns
 
     async def get_session_turn(self, session_id: str, turn_id: str) -> Turn | None:
@@ -170,3 +170,43 @@ class AgentPersistence:
             key=f"num_infer_iters_in_turn:{self.agent_id}:{session_id}:{turn_id}",
         )
         return int(value) if value else None
+
+    async def list_sessions(self) -> list[Session]:
+        values = await self.kvstore.values_in_range(
+            start_key=f"session:{self.agent_id}:",
+            end_key=f"session:{self.agent_id}:\xff\xff\xff\xff",
+        )
+        sessions = []
+        for value in values:
+            try:
+                session_info = Session(**json.loads(value))
+                sessions.append(session_info)
+            except Exception as e:
+                log.error(f"Error parsing session info: {e}")
+                continue
+        return sessions
+
+    async def delete_session_turns(self, session_id: str) -> None:
+        """Delete all turns and their associated data for a session.
+
+        Args:
+            session_id: The ID of the session whose turns should be deleted.
+        """
+        turns = await self.get_session_turns(session_id)
+        for turn in turns:
+            await self.kvstore.delete(key=f"session:{self.agent_id}:{session_id}:{turn.turn_id}")
+
+    async def delete_session(self, session_id: str) -> None:
+        """Delete a session and all its associated turns.
+
+        Args:
+            session_id: The ID of the session to delete.
+
+        Raises:
+            ValueError: If the session does not exist.
+        """
+        session_info = await self.get_session_info(session_id)
+        if session_info is None:
+            raise ValueError(f"Session {session_id} not found")
+
+        await self.kvstore.delete(key=f"session:{self.agent_id}:{session_id}")
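
Taken together, the new helpers give a two-step cleanup path. A hypothetical caller (the persistence and kvstore variables are assumed) would delete the turn records before the session record, since delete_session raises ValueError once the session entry is gone:

    # Hypothetical usage sketch of the new deletion helpers.
    persistence = AgentPersistence(agent_id="agent-1", kvstore=kvstore)
    await persistence.delete_session_turns(session_id)
    await persistence.delete_session(session_id)
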
@@ -11,9 +11,9 @@ from llama_stack.apis.common.responses import PaginatedResponse
 from llama_stack.apis.datasetio import DatasetIO
 from llama_stack.apis.datasets import Dataset
 from llama_stack.providers.datatypes import DatasetsProtocolPrivate
-from llama_stack.providers.utils.datasetio.pagination import paginate_records
 from llama_stack.providers.utils.datasetio.url_utils import get_dataframe_from_uri
 from llama_stack.providers.utils.kvstore import kvstore_impl
+from llama_stack.providers.utils.pagination import paginate_records
 
 from .config import LocalFSDatasetIOConfig
@@ -64,7 +64,7 @@ class LocalFSDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
         # Load existing datasets from kvstore
         start_key = DATASETS_PREFIX
         end_key = f"{DATASETS_PREFIX}\xff"
-        stored_datasets = await self.kvstore.range(start_key, end_key)
+        stored_datasets = await self.kvstore.values_in_range(start_key, end_key)
 
         for dataset in stored_datasets:
             dataset = Dataset.model_validate_json(dataset)
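
paginate_records now lives in the shared llama_stack.providers.utils.pagination module rather than a datasetio-specific one. A usage sketch, assuming the helper keeps a record-list-plus-window signature and returns a PaginatedResponse:

    from llama_stack.providers.utils.pagination import paginate_records

    # Assumed signature: slice the records by start_index/limit into one page.
    records = [{"id": i} for i in range(100)]
    page = paginate_records(records, start_index=0, limit=10)
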
@@ -58,7 +58,7 @@ class MetaReferenceEvalImpl(
         # Load existing benchmarks from kvstore
         start_key = EVAL_TASKS_PREFIX
         end_key = f"{EVAL_TASKS_PREFIX}\xff"
-        stored_benchmarks = await self.kvstore.range(start_key, end_key)
+        stored_benchmarks = await self.kvstore.values_in_range(start_key, end_key)
 
         for benchmark in stored_benchmarks:
             benchmark = Benchmark.model_validate_json(benchmark)
@@ -28,7 +28,7 @@ from llama_stack.apis.inference import (
     CompletionRequest,
     CompletionResponse,
     CompletionResponseStreamChunk,
-    Inference,
+    InferenceProvider,
     InterleavedContent,
     LogProbConfig,
     Message,
@@ -86,7 +86,7 @@ class MetaReferenceInferenceImpl(
     OpenAICompletionToLlamaStackMixin,
     OpenAIChatCompletionToLlamaStackMixin,
     SentenceTransformerEmbeddingMixin,
-    Inference,
+    InferenceProvider,
     ModelsProtocolPrivate,
 ):
     def __init__(self, config: MetaReferenceInferenceConfig) -> None:
@@ -9,7 +9,7 @@ from collections.abc import AsyncGenerator
 
 from llama_stack.apis.inference import (
     CompletionResponse,
-    Inference,
+    InferenceProvider,
     InterleavedContent,
     LogProbConfig,
     Message,
@@ -38,7 +38,7 @@ class SentenceTransformersInferenceImpl(
     OpenAIChatCompletionToLlamaStackMixin,
     OpenAICompletionToLlamaStackMixin,
     SentenceTransformerEmbeddingMixin,
-    Inference,
+    InferenceProvider,
     ModelsProtocolPrivate,
 ):
     def __init__(self, config: SentenceTransformersInferenceConfig) -> None:
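
Both inline inference implementations swap Inference for InferenceProvider in their base-class lists. A sketch of the resulting pattern for an inline provider (the class name is illustrative, and it assumes InferenceProvider carries the request-serving methods these classes actually implement while the full Inference API type stays with the routing layer):

    from llama_stack.apis.inference import InferenceProvider
    from llama_stack.providers.datatypes import ModelsProtocolPrivate

    # Illustrative shape only: implement the provider-facing protocol.
    class MyInlineInferenceImpl(InferenceProvider, ModelsProtocolPrivate):
        def __init__(self, config) -> None:
            self.config = config
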
35 llama_stack/providers/inline/post_training/common/utils.py Normal file
@@ -0,0 +1,35 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+import gc
+
+
+def evacuate_model_from_device(model, device: str):
+    """Safely clear a model from memory and free device resources.
+    This function handles the proper cleanup of a model by:
+    1. Moving the model to CPU if it's on a non-CPU device
+    2. Deleting the model object to free memory
+    3. Running garbage collection
+    4. Clearing CUDA cache if the model was on a CUDA device
+    Args:
+        model: The PyTorch model to clear
+        device: The device type the model is currently on ('cuda', 'mps', 'cpu')
+    Note:
+        - For CUDA devices, this will clear the CUDA cache after moving the model to CPU
+        - For MPS devices, only moves the model to CPU (no cache clearing available)
+        - For CPU devices, only deletes the model object and runs garbage collection
+    """
+    if device != "cpu":
+        model.to("cpu")
+
+    del model
+    gc.collect()
+
+    if device == "cuda":
+        # we need to import such that this is only imported when the method is called
+        import torch
+
+        torch.cuda.empty_cache()
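
A usage sketch for the new helper (the model variable is hypothetical, and torch is assumed to be installed when device is "cuda"). Note that del inside the function only drops the function's local reference, so the caller should drop its own as well:

    from llama_stack.providers.inline.post_training.common.utils import (
        evacuate_model_from_device,
    )

    # Hypothetical: release a finished fine-tuning model's device memory.
    evacuate_model_from_device(model, device="cuda")
    del model  # drop the caller's reference too so gc can actually reclaim it
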
27 llama_stack/providers/inline/post_training/huggingface/__init__.py Normal file
@@ -0,0 +1,27 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from typing import Any
+
+from llama_stack.distribution.datatypes import Api
+
+from .config import HuggingFacePostTrainingConfig
+
+# post_training api and the huggingface provider is still experimental and under heavy development
+
+
+async def get_provider_impl(
+    config: HuggingFacePostTrainingConfig,
+    deps: dict[Api, Any],
+):
+    from .post_training import HuggingFacePostTrainingImpl
+
+    impl = HuggingFacePostTrainingImpl(
+        config,
+        deps[Api.datasetio],
+        deps[Api.datasets],
+    )
+    return impl
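
This factory is what the distribution resolver invokes for the provider. A wiring sketch with assumed dependency objects (datasetio_impl and datasets_impl stand in for the resolved implementations the stack would normally pass):

    from llama_stack.distribution.datatypes import Api

    # Hypothetical wiring, outside the stack's own resolver.
    impl = await get_provider_impl(
        HuggingFacePostTrainingConfig(device="cpu"),
        {Api.datasetio: datasetio_impl, Api.datasets: datasets_impl},
    )
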
72 llama_stack/providers/inline/post_training/huggingface/config.py Normal file
@@ -0,0 +1,72 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from typing import Any, Literal
+
+from pydantic import BaseModel
+
+
+class HuggingFacePostTrainingConfig(BaseModel):
+    # Device to run training on (cuda, cpu, mps)
+    device: str = "cuda"
+
+    # Distributed training backend if using multiple devices
+    # fsdp: Fully Sharded Data Parallel
+    # deepspeed: DeepSpeed ZeRO optimization
+    distributed_backend: Literal["fsdp", "deepspeed"] | None = None
+
+    # Format for saving model checkpoints
+    # full_state: Save complete model state
+    # huggingface: Save in HuggingFace format (recommended for compatibility)
+    checkpoint_format: Literal["full_state", "huggingface"] | None = "huggingface"
+
+    # Template for formatting chat inputs and outputs
+    # Used to structure the conversation format for training
+    chat_template: str = "<|user|>\n{input}\n<|assistant|>\n{output}"
+
+    # Model-specific configuration parameters
+    # trust_remote_code: Allow execution of custom model code
+    # attn_implementation: Use SDPA (Scaled Dot Product Attention) for better performance
+    model_specific_config: dict = {
+        "trust_remote_code": True,
+        "attn_implementation": "sdpa",
+    }
+
+    # Maximum sequence length for training
+    # Set to 2048 as this is the maximum that works reliably on MPS (Apple Silicon)
+    # Longer sequences may cause memory issues on MPS devices
+    max_seq_length: int = 2048
+
+    # Enable gradient checkpointing to reduce memory usage
+    # Trades computation for memory by recomputing activations
+    gradient_checkpointing: bool = False
+
+    # Maximum number of checkpoints to keep
+    # Older checkpoints are deleted when this limit is reached
+    save_total_limit: int = 3
+
+    # Number of training steps between logging updates
+    logging_steps: int = 10
+
+    # Ratio of training steps used for learning rate warmup
+    # Helps stabilize early training
+    warmup_ratio: float = 0.1
+
+    # L2 regularization coefficient
+    # Helps prevent overfitting
+    weight_decay: float = 0.01
+
+    # Number of worker processes for data loading
+    # Higher values can improve data loading speed but increase memory usage
+    dataloader_num_workers: int = 4
+
+    # Whether to pin memory in data loader
+    # Can improve data transfer speed to GPU but uses more memory
+    dataloader_pin_memory: bool = True
+
+    @classmethod
+    def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> dict[str, Any]:
+        return {"checkpoint_format": "huggingface", "distributed_backend": None, "device": "cpu"}
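
Since every field has a default, a run config only needs the overrides it cares about; sample_run_config mirrors the CPU-safe values used in distribution templates. An illustrative instantiation:

    # Sketch: override a few fields, keep the rest of the defaults above.
    config = HuggingFacePostTrainingConfig(
        device="cpu",
        checkpoint_format="huggingface",
        gradient_checkpointing=True,
    )
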
176 llama_stack/providers/inline/post_training/huggingface/post_training.py Normal file
@@ -0,0 +1,176 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+from enum import Enum
+from typing import Any
+
+from llama_stack.apis.datasetio import DatasetIO
+from llama_stack.apis.datasets import Datasets
+from llama_stack.apis.post_training import (
+    AlgorithmConfig,
+    Checkpoint,
+    DPOAlignmentConfig,
+    JobStatus,
+    ListPostTrainingJobsResponse,
+    PostTrainingJob,
+    PostTrainingJobArtifactsResponse,
+    PostTrainingJobStatusResponse,
+    TrainingConfig,
+)
+from llama_stack.providers.inline.post_training.huggingface.config import (
+    HuggingFacePostTrainingConfig,
+)
+from llama_stack.providers.inline.post_training.huggingface.recipes.finetune_single_device import (
+    HFFinetuningSingleDevice,
+)
+from llama_stack.providers.utils.scheduler import JobArtifact, Scheduler
+from llama_stack.providers.utils.scheduler import JobStatus as SchedulerJobStatus
+from llama_stack.schema_utils import webmethod
+
+
+class TrainingArtifactType(Enum):
+    CHECKPOINT = "checkpoint"
+    RESOURCES_STATS = "resources_stats"
+
+
+_JOB_TYPE_SUPERVISED_FINE_TUNE = "supervised-fine-tune"
+
+
+class HuggingFacePostTrainingImpl:
+    def __init__(
+        self,
+        config: HuggingFacePostTrainingConfig,
+        datasetio_api: DatasetIO,
+        datasets: Datasets,
+    ) -> None:
+        self.config = config
+        self.datasetio_api = datasetio_api
+        self.datasets_api = datasets
+        self._scheduler = Scheduler()
+
+    async def shutdown(self) -> None:
+        await self._scheduler.shutdown()
+
+    @staticmethod
+    def _checkpoint_to_artifact(checkpoint: Checkpoint) -> JobArtifact:
+        return JobArtifact(
+            type=TrainingArtifactType.CHECKPOINT.value,
+            name=checkpoint.identifier,
+            uri=checkpoint.path,
+            metadata=dict(checkpoint),
+        )
+
+    @staticmethod
+    def _resources_stats_to_artifact(resources_stats: dict[str, Any]) -> JobArtifact:
+        return JobArtifact(
+            type=TrainingArtifactType.RESOURCES_STATS.value,
+            name=TrainingArtifactType.RESOURCES_STATS.value,
+            metadata=resources_stats,
+        )
+
+    async def supervised_fine_tune(
+        self,
+        job_uuid: str,
+        training_config: TrainingConfig,
+        hyperparam_search_config: dict[str, Any],
+        logger_config: dict[str, Any],
+        model: str,
+        checkpoint_dir: str | None = None,
+        algorithm_config: AlgorithmConfig | None = None,
+    ) -> PostTrainingJob:
+        async def handler(on_log_message_cb, on_status_change_cb, on_artifact_collected_cb):
+            on_log_message_cb("Starting HF finetuning")
+
+            recipe = HFFinetuningSingleDevice(
+                job_uuid=job_uuid,
+                datasetio_api=self.datasetio_api,
+                datasets_api=self.datasets_api,
+            )
+
+            resources_allocated, checkpoints = await recipe.train(
+                model=model,
+                output_dir=checkpoint_dir,
+                job_uuid=job_uuid,
+                lora_config=algorithm_config,
+                config=training_config,
+                provider_config=self.config,
+            )
+
+            on_artifact_collected_cb(self._resources_stats_to_artifact(resources_allocated))
+            if checkpoints:
+                for checkpoint in checkpoints:
+                    artifact = self._checkpoint_to_artifact(checkpoint)
+                    on_artifact_collected_cb(artifact)
+
+            on_status_change_cb(SchedulerJobStatus.completed)
+            on_log_message_cb("HF finetuning completed")
+
+        job_uuid = self._scheduler.schedule(_JOB_TYPE_SUPERVISED_FINE_TUNE, job_uuid, handler)
+        return PostTrainingJob(job_uuid=job_uuid)
+
+    async def preference_optimize(
+        self,
+        job_uuid: str,
+        finetuned_model: str,
+        algorithm_config: DPOAlignmentConfig,
+        training_config: TrainingConfig,
+        hyperparam_search_config: dict[str, Any],
+        logger_config: dict[str, Any],
+    ) -> PostTrainingJob:
+        raise NotImplementedError("DPO alignment is not implemented yet")
+
+    async def get_training_jobs(self) -> ListPostTrainingJobsResponse:
+        return ListPostTrainingJobsResponse(
+            data=[PostTrainingJob(job_uuid=job.id) for job in self._scheduler.get_jobs()]
+        )
+
+    @staticmethod
+    def _get_artifacts_metadata_by_type(job, artifact_type):
+        return [artifact.metadata for artifact in job.artifacts if artifact.type == artifact_type]
+
+    @classmethod
+    def _get_checkpoints(cls, job):
+        return cls._get_artifacts_metadata_by_type(job, TrainingArtifactType.CHECKPOINT.value)
+
+    @classmethod
+    def _get_resources_allocated(cls, job):
+        data = cls._get_artifacts_metadata_by_type(job, TrainingArtifactType.RESOURCES_STATS.value)
+        return data[0] if data else None
+
+    @webmethod(route="/post-training/job/status")
+    async def get_training_job_status(self, job_uuid: str) -> PostTrainingJobStatusResponse | None:
+        job = self._scheduler.get_job(job_uuid)
+
+        match job.status:
+            # TODO: Add support for other statuses to API
+            case SchedulerJobStatus.new | SchedulerJobStatus.scheduled:
+                status = JobStatus.scheduled
+            case SchedulerJobStatus.running:
+                status = JobStatus.in_progress
+            case SchedulerJobStatus.completed:
+                status = JobStatus.completed
+            case SchedulerJobStatus.failed:
+                status = JobStatus.failed
+            case _:
+                raise NotImplementedError()
+
+        return PostTrainingJobStatusResponse(
+            job_uuid=job_uuid,
+            status=status,
+            scheduled_at=job.scheduled_at,
+            started_at=job.started_at,
+            completed_at=job.completed_at,
+            checkpoints=self._get_checkpoints(job),
+            resources_allocated=self._get_resources_allocated(job),
+        )
+
+    @webmethod(route="/post-training/job/cancel")
+    async def cancel_training_job(self, job_uuid: str) -> None:
+        self._scheduler.cancel(job_uuid)
+
+    @webmethod(route="/post-training/job/artifacts")
+    async def get_training_job_artifacts(self, job_uuid: str) -> PostTrainingJobArtifactsResponse | None:
+        job = self._scheduler.get_job(job_uuid)
+        return PostTrainingJobArtifactsResponse(job_uuid=job_uuid, checkpoints=self._get_checkpoints(job))
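
A hypothetical end-to-end flow against this implementation, reusing the impl from the wiring sketch above (training_config is assumed to be a populated TrainingConfig, and the model id is a placeholder): schedule the job, then poll its status; checkpoints and resource stats surface through the artifacts the handler collects.

    # Hypothetical client flow; the scheduler runs the handler in the background.
    job = await impl.supervised_fine_tune(
        job_uuid="job-123",
        training_config=training_config,  # assumed, built elsewhere
        hyperparam_search_config={},
        logger_config={},
        model="my-base-model",
    )
    status = await impl.get_training_job_status(job_uuid=job.job_uuid)
    artifacts = await impl.get_training_job_artifacts(job_uuid=job.job_uuid)
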
Some files were not shown because too many files have changed in this diff.