mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-20 11:47:00 +00:00
Merge branch 'main' into nvidia-e2e-notebook
This commit is contained in:
commit
51b68b4be6
234 changed files with 21943 additions and 7540 deletions
2
.github/CODEOWNERS
vendored
2
.github/CODEOWNERS
vendored
|
@ -2,4 +2,4 @@
|
|||
|
||||
# These owners will be the default owners for everything in
|
||||
# the repo. Unless a later match takes precedence,
|
||||
* @ashwinb @yanxi0830 @hardikjshah @dltn @raghotham @dineshyv @vladimirivic @sixianyi0721 @ehhuang @terrytangyuan @SLR722 @leseb
|
||||
* @ashwinb @yanxi0830 @hardikjshah @raghotham @ehhuang @terrytangyuan @leseb @bbrowning
|
||||
|
|
2
.github/TRIAGERS.md
vendored
2
.github/TRIAGERS.md
vendored
|
@ -1,2 +1,2 @@
|
|||
# This file documents Triage members in the Llama Stack community
|
||||
@franciscojavierarceo @leseb
|
||||
@bbrowning @booxter @franciscojavierarceo @leseb
|
||||
|
|
35
.github/workflows/integration-auth-tests.yml
vendored
35
.github/workflows/integration-auth-tests.yml
vendored
|
@ -28,12 +28,13 @@ jobs:
|
|||
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
|
||||
|
||||
- name: Install uv
|
||||
uses: astral-sh/setup-uv@v5
|
||||
uses: astral-sh/setup-uv@6b9c6063abd6010835644d4c2e1bef4cf5cd0fca # v6.0.1
|
||||
with:
|
||||
python-version: "3.10"
|
||||
activate-environment: true
|
||||
|
||||
- name: Set Up Environment and Install Dependencies
|
||||
run: |
|
||||
|
@ -43,7 +44,7 @@ jobs:
|
|||
|
||||
- name: Install minikube
|
||||
if: ${{ matrix.auth-provider == 'kubernetes' }}
|
||||
uses: medyagh/setup-minikube@latest
|
||||
uses: medyagh/setup-minikube@cea33675329b799adccc9526aa5daccc26cd5052 # v0.0.19
|
||||
|
||||
- name: Start minikube
|
||||
if: ${{ matrix.auth-provider == 'kubernetes' }}
|
||||
|
@ -74,32 +75,8 @@ jobs:
|
|||
cat <<'EOF' > $run_dir/run.yaml
|
||||
version: '2'
|
||||
image_name: kube
|
||||
apis:
|
||||
- agents
|
||||
- datasetio
|
||||
- eval
|
||||
- inference
|
||||
- safety
|
||||
- scoring
|
||||
- telemetry
|
||||
- tool_runtime
|
||||
- vector_io
|
||||
providers:
|
||||
agents:
|
||||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
config:
|
||||
persistence_store:
|
||||
type: sqlite
|
||||
namespace: null
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ollama}/agents_store.db
|
||||
telemetry:
|
||||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
config:
|
||||
service_name: "${env.OTEL_SERVICE_NAME:\u200B}"
|
||||
sinks: ${env.TELEMETRY_SINKS:console,sqlite}
|
||||
sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/ollama/trace_store.db}
|
||||
apis: []
|
||||
providers: {}
|
||||
server:
|
||||
port: 8321
|
||||
EOF
|
||||
|
|
20
.github/workflows/integration-tests.yml
vendored
20
.github/workflows/integration-tests.yml
vendored
|
@ -33,7 +33,7 @@ jobs:
|
|||
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
|
||||
|
||||
- name: Install uv
|
||||
uses: astral-sh/setup-uv@c7f87aa956e4c323abf06d5dec078e358f6b4d04 # v6.0.0
|
||||
uses: astral-sh/setup-uv@6b9c6063abd6010835644d4c2e1bef4cf5cd0fca # v6.0.1
|
||||
with:
|
||||
python-version: "3.10"
|
||||
activate-environment: true
|
||||
|
@ -58,7 +58,7 @@ jobs:
|
|||
INFERENCE_MODEL: "meta-llama/Llama-3.2-3B-Instruct"
|
||||
run: |
|
||||
source .venv/bin/activate
|
||||
nohup uv run llama stack run ./llama_stack/templates/ollama/run.yaml --image-type venv > server.log 2>&1 &
|
||||
LLAMA_STACK_LOG_FILE=server.log nohup uv run llama stack run ./llama_stack/templates/ollama/run.yaml --image-type venv &
|
||||
|
||||
- name: Wait for Llama Stack server to be ready
|
||||
if: matrix.client-type == 'http'
|
||||
|
@ -85,6 +85,11 @@ jobs:
|
|||
echo "Ollama health check failed"
|
||||
exit 1
|
||||
fi
|
||||
- name: Check Storage and Memory Available Before Tests
|
||||
if: ${{ always() }}
|
||||
run: |
|
||||
free -h
|
||||
df -h
|
||||
|
||||
- name: Run Integration Tests
|
||||
env:
|
||||
|
@ -100,13 +105,20 @@ jobs:
|
|||
--text-model="meta-llama/Llama-3.2-3B-Instruct" \
|
||||
--embedding-model=all-MiniLM-L6-v2
|
||||
|
||||
- name: Check Storage and Memory Available After Tests
|
||||
if: ${{ always() }}
|
||||
run: |
|
||||
free -h
|
||||
df -h
|
||||
|
||||
- name: Write ollama logs to file
|
||||
if: ${{ always() }}
|
||||
run: |
|
||||
sudo journalctl -u ollama.service > ollama.log
|
||||
|
||||
- name: Upload all logs to artifacts
|
||||
if: always()
|
||||
uses: actions/upload-artifact@v4
|
||||
if: ${{ always() }}
|
||||
uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # v4.6.2
|
||||
with:
|
||||
name: logs-${{ github.run_id }}-${{ github.run_attempt }}-${{ matrix.client-type }}-${{ matrix.test-type }}
|
||||
path: |
|
||||
|
|
18
.github/workflows/providers-build.yml
vendored
18
.github/workflows/providers-build.yml
vendored
|
@ -56,7 +56,7 @@ jobs:
|
|||
python-version: '3.10'
|
||||
|
||||
- name: Install uv
|
||||
uses: astral-sh/setup-uv@c7f87aa956e4c323abf06d5dec078e358f6b4d04 # v6.0.0
|
||||
uses: astral-sh/setup-uv@6b9c6063abd6010835644d4c2e1bef4cf5cd0fca # v6.0.1
|
||||
with:
|
||||
python-version: "3.10"
|
||||
|
||||
|
@ -94,7 +94,7 @@ jobs:
|
|||
python-version: '3.10'
|
||||
|
||||
- name: Install uv
|
||||
uses: astral-sh/setup-uv@c7f87aa956e4c323abf06d5dec078e358f6b4d04 # v6.0.0
|
||||
uses: astral-sh/setup-uv@6b9c6063abd6010835644d4c2e1bef4cf5cd0fca # v6.0.1
|
||||
with:
|
||||
python-version: "3.10"
|
||||
|
||||
|
@ -120,7 +120,7 @@ jobs:
|
|||
python-version: '3.10'
|
||||
|
||||
- name: Install uv
|
||||
uses: astral-sh/setup-uv@c7f87aa956e4c323abf06d5dec078e358f6b4d04 # v6.0.0
|
||||
uses: astral-sh/setup-uv@6b9c6063abd6010835644d4c2e1bef4cf5cd0fca # v6.0.1
|
||||
with:
|
||||
python-version: "3.10"
|
||||
|
||||
|
@ -132,9 +132,9 @@ jobs:
|
|||
|
||||
- name: Build a single provider
|
||||
run: |
|
||||
yq -i '.image_type = "container"' llama_stack/templates/dev/build.yaml
|
||||
yq -i '.image_name = "test"' llama_stack/templates/dev/build.yaml
|
||||
USE_COPY_NOT_MOUNT=true LLAMA_STACK_DIR=. uv run llama stack build --config llama_stack/templates/dev/build.yaml
|
||||
yq -i '.image_type = "container"' llama_stack/templates/starter/build.yaml
|
||||
yq -i '.image_name = "test"' llama_stack/templates/starter/build.yaml
|
||||
USE_COPY_NOT_MOUNT=true LLAMA_STACK_DIR=. uv run llama stack build --config llama_stack/templates/starter/build.yaml
|
||||
|
||||
- name: Inspect the container image entrypoint
|
||||
run: |
|
||||
|
@ -158,7 +158,7 @@ jobs:
|
|||
python-version: '3.10'
|
||||
|
||||
- name: Install uv
|
||||
uses: astral-sh/setup-uv@0c5e2b8115b80b4c7c5ddf6ffdd634974642d182 # v5.4.1
|
||||
uses: astral-sh/setup-uv@6b9c6063abd6010835644d4c2e1bef4cf5cd0fca # v6.0.1
|
||||
with:
|
||||
python-version: "3.10"
|
||||
|
||||
|
@ -174,14 +174,14 @@ jobs:
|
|||
.image_type = "container" |
|
||||
.image_name = "ubi9-test" |
|
||||
.distribution_spec.container_image = "registry.access.redhat.com/ubi9:latest"
|
||||
' llama_stack/templates/dev/build.yaml
|
||||
' llama_stack/templates/starter/build.yaml
|
||||
|
||||
- name: Build dev container (UBI9)
|
||||
env:
|
||||
USE_COPY_NOT_MOUNT: "true"
|
||||
LLAMA_STACK_DIR: "."
|
||||
run: |
|
||||
uv run llama stack build --config llama_stack/templates/dev/build.yaml
|
||||
uv run llama stack build --config llama_stack/templates/starter/build.yaml
|
||||
|
||||
- name: Inspect UBI9 image
|
||||
run: |
|
||||
|
|
11
.github/workflows/test-external-providers.yml
vendored
11
.github/workflows/test-external-providers.yml
vendored
|
@ -23,10 +23,10 @@ jobs:
|
|||
# container and point 'uv pip install' to the correct path...
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
|
||||
|
||||
- name: Install uv
|
||||
uses: astral-sh/setup-uv@v6
|
||||
uses: astral-sh/setup-uv@6b9c6063abd6010835644d4c2e1bef4cf5cd0fca # v6.0.1
|
||||
with:
|
||||
python-version: "3.10"
|
||||
|
||||
|
@ -47,8 +47,8 @@ jobs:
|
|||
|
||||
- name: Create provider configuration
|
||||
run: |
|
||||
mkdir -p /tmp/providers.d/remote/inference
|
||||
cp tests/external-provider/llama-stack-provider-ollama/custom_ollama.yaml /tmp/providers.d/remote/inference/custom_ollama.yaml
|
||||
mkdir -p /home/runner/.llama/providers.d/remote/inference
|
||||
cp tests/external-provider/llama-stack-provider-ollama/custom_ollama.yaml /home/runner/.llama/providers.d/remote/inference/custom_ollama.yaml
|
||||
|
||||
- name: Build distro from config file
|
||||
run: |
|
||||
|
@ -66,7 +66,7 @@ jobs:
|
|||
- name: Wait for Llama Stack server to be ready
|
||||
run: |
|
||||
for i in {1..30}; do
|
||||
if ! grep -q "remote::custom_ollama from /tmp/providers.d/remote/inference/custom_ollama.yaml" server.log; then
|
||||
if ! grep -q "remote::custom_ollama from /home/runner/.llama/providers.d/remote/inference/custom_ollama.yaml" server.log; then
|
||||
echo "Waiting for Llama Stack server to load the provider..."
|
||||
sleep 1
|
||||
else
|
||||
|
@ -75,4 +75,5 @@ jobs:
|
|||
fi
|
||||
done
|
||||
echo "Provider failed to load"
|
||||
cat server.log
|
||||
exit 1
|
||||
|
|
2
.github/workflows/unit-tests.yml
vendored
2
.github/workflows/unit-tests.yml
vendored
|
@ -37,7 +37,7 @@ jobs:
|
|||
with:
|
||||
python-version: ${{ matrix.python }}
|
||||
|
||||
- uses: astral-sh/setup-uv@c7f87aa956e4c323abf06d5dec078e358f6b4d04 # v6.0.0
|
||||
- uses: astral-sh/setup-uv@6b9c6063abd6010835644d4c2e1bef4cf5cd0fca # v6.0.1
|
||||
with:
|
||||
python-version: ${{ matrix.python }}
|
||||
enable-cache: false
|
||||
|
|
9
.github/workflows/update-readthedocs.yml
vendored
9
.github/workflows/update-readthedocs.yml
vendored
|
@ -14,6 +14,8 @@ on:
|
|||
- 'docs/**'
|
||||
- 'pyproject.toml'
|
||||
- '.github/workflows/update-readthedocs.yml'
|
||||
tags:
|
||||
- '*'
|
||||
pull_request:
|
||||
branches:
|
||||
- main
|
||||
|
@ -41,7 +43,7 @@ jobs:
|
|||
python-version: '3.11'
|
||||
|
||||
- name: Install the latest version of uv
|
||||
uses: astral-sh/setup-uv@c7f87aa956e4c323abf06d5dec078e358f6b4d04 # v6.0.0
|
||||
uses: astral-sh/setup-uv@6b9c6063abd6010835644d4c2e1bef4cf5cd0fca # v6.0.1
|
||||
|
||||
- name: Sync with uv
|
||||
run: uv sync --extra docs
|
||||
|
@ -61,7 +63,10 @@ jobs:
|
|||
|
||||
response=$(curl -X POST \
|
||||
-H "Content-Type: application/json" \
|
||||
-d "{\"token\": \"$TOKEN\"}" \
|
||||
-d "{
|
||||
\"token\": \"$TOKEN\",
|
||||
\"version\": \"$GITHUB_REF_NAME\"
|
||||
}" \
|
||||
https://readthedocs.org/api/v2/webhook/llama-stack/289768/)
|
||||
|
||||
echo "Response: $response"
|
||||
|
|
|
@ -106,6 +106,14 @@ repos:
|
|||
pass_filenames: false
|
||||
require_serial: true
|
||||
files: ^llama_stack/apis/|^docs/openapi_generator/
|
||||
- id: check-workflows-use-hashes
|
||||
name: Check GitHub Actions use SHA-pinned actions
|
||||
entry: ./scripts/check-workflows-use-hashes.sh
|
||||
language: system
|
||||
pass_filenames: false
|
||||
require_serial: true
|
||||
always_run: true
|
||||
files: ^\.github/workflows/.*\.ya?ml$
|
||||
|
||||
ci:
|
||||
autofix_commit_msg: 🎨 [pre-commit.ci] Auto format from pre-commit.com hooks
|
||||
|
|
21
CHANGELOG.md
21
CHANGELOG.md
|
@ -1,5 +1,26 @@
|
|||
# Changelog
|
||||
|
||||
# v0.2.5
|
||||
Published on: 2025-05-04T20:16:49Z
|
||||
|
||||
|
||||
|
||||
---
|
||||
|
||||
# v0.2.4
|
||||
Published on: 2025-04-29T17:26:01Z
|
||||
|
||||
## Highlights
|
||||
|
||||
* One-liner to install and run Llama Stack yay! by @reluctantfuturist in https://github.com/meta-llama/llama-stack/pull/1383
|
||||
* support for NVIDIA NeMo datastore by @raspawar in https://github.com/meta-llama/llama-stack/pull/1852
|
||||
* (yuge!) Kubernetes authentication by @leseb in https://github.com/meta-llama/llama-stack/pull/1778
|
||||
* (yuge!) OpenAI Responses API by @bbrowning in https://github.com/meta-llama/llama-stack/pull/1989
|
||||
* add api.llama provider, llama-guard-4 model by @ashwinb in https://github.com/meta-llama/llama-stack/pull/2058
|
||||
|
||||
|
||||
---
|
||||
|
||||
# v0.2.3
|
||||
Published on: 2025-04-25T22:46:21Z
|
||||
|
||||
|
|
|
@ -110,25 +110,9 @@ uv run pre-commit run --all-files
|
|||
> [!CAUTION]
|
||||
> Before pushing your changes, make sure that the pre-commit hooks have passed successfully.
|
||||
|
||||
## Running unit tests
|
||||
## Running tests
|
||||
|
||||
You can run the unit tests by running:
|
||||
|
||||
```bash
|
||||
source .venv/bin/activate
|
||||
./scripts/unit-tests.sh
|
||||
```
|
||||
|
||||
If you'd like to run for a non-default version of Python (currently 3.10), pass `PYTHON_VERSION` variable as follows:
|
||||
|
||||
```
|
||||
source .venv/bin/activate
|
||||
PYTHON_VERSION=3.13 ./scripts/unit-tests.sh
|
||||
```
|
||||
|
||||
## Running integration tests
|
||||
|
||||
You can run integration tests following the instructions [here](tests/integration/README.md).
|
||||
You can find the Llama Stack testing documentation here [here](tests/README.md).
|
||||
|
||||
## Adding a new dependency to the project
|
||||
|
||||
|
@ -153,6 +137,8 @@ uv sync
|
|||
justification for bypassing the check.
|
||||
* When using `# type: ignore` to suppress a mypy warning, include a comment explaining the
|
||||
justification for bypassing the check.
|
||||
* Don't use unicode characters in the codebase. ASCII-only is preferred for compatibility or
|
||||
readability reasons.
|
||||
|
||||
## Common Tasks
|
||||
|
||||
|
|
|
@ -7,7 +7,7 @@
|
|||
[](https://github.com/meta-llama/llama-stack/actions/workflows/unit-tests.yml?query=branch%3Amain)
|
||||
[](https://github.com/meta-llama/llama-stack/actions/workflows/integration-tests.yml?query=branch%3Amain)
|
||||
|
||||
[**Quick Start**](https://llama-stack.readthedocs.io/en/latest/getting_started/index.html) | [**Documentation**](https://llama-stack.readthedocs.io/en/latest/index.html) | [**Colab Notebook**](./docs/getting_started.ipynb)
|
||||
[**Quick Start**](https://llama-stack.readthedocs.io/en/latest/getting_started/index.html) | [**Documentation**](https://llama-stack.readthedocs.io/en/latest/index.html) | [**Colab Notebook**](./docs/getting_started.ipynb) | [**Discord**](https://discord.gg/llama-stack)
|
||||
|
||||
### ✨🎉 Llama 4 Support 🎉✨
|
||||
We released [Version 0.2.0](https://github.com/meta-llama/llama-stack/releases/tag/v0.2.0) with support for the Llama 4 herd of models released by Meta.
|
||||
|
|
3954
docs/_static/llama-stack-spec.html
vendored
3954
docs/_static/llama-stack-spec.html
vendored
File diff suppressed because it is too large
Load diff
2930
docs/_static/llama-stack-spec.yaml
vendored
2930
docs/_static/llama-stack-spec.yaml
vendored
File diff suppressed because it is too large
Load diff
|
@ -1050,8 +1050,6 @@
|
|||
"text/html": [
|
||||
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\">ToolGroup</span><span style=\"font-weight: bold\">(</span>\n",
|
||||
"<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ </span><span style=\"color: #808000; text-decoration-color: #808000\">identifier</span>=<span style=\"color: #008000; text-decoration-color: #008000\">'builtin::code_interpreter'</span>,\n",
|
||||
"<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ </span><span style=\"color: #808000; text-decoration-color: #808000\">provider_id</span>=<span style=\"color: #008000; text-decoration-color: #008000\">'code-interpreter'</span>,\n",
|
||||
"<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ </span><span style=\"color: #808000; text-decoration-color: #808000\">provider_resource_id</span>=<span style=\"color: #008000; text-decoration-color: #008000\">'builtin::code_interpreter'</span>,\n",
|
||||
"<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ </span><span style=\"color: #808000; text-decoration-color: #808000\">type</span>=<span style=\"color: #008000; text-decoration-color: #008000\">'tool_group'</span>,\n",
|
||||
"<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ </span><span style=\"color: #808000; text-decoration-color: #808000\">args</span>=<span style=\"color: #800080; text-decoration-color: #800080; font-style: italic\">None</span>,\n",
|
||||
"<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ </span><span style=\"color: #808000; text-decoration-color: #808000\">mcp_endpoint</span>=<span style=\"color: #800080; text-decoration-color: #800080; font-style: italic\">None</span>\n",
|
||||
|
@ -1061,7 +1059,6 @@
|
|||
"text/plain": [
|
||||
"\u001b[1;35mToolGroup\u001b[0m\u001b[1m(\u001b[0m\n",
|
||||
"\u001b[2;32m│ \u001b[0m\u001b[33midentifier\u001b[0m=\u001b[32m'builtin::code_interpreter'\u001b[0m,\n",
|
||||
"\u001b[2;32m│ \u001b[0m\u001b[33mprovider_id\u001b[0m=\u001b[32m'code-interpreter'\u001b[0m,\n",
|
||||
"\u001b[2;32m│ \u001b[0m\u001b[33mprovider_resource_id\u001b[0m=\u001b[32m'builtin::code_interpreter'\u001b[0m,\n",
|
||||
"\u001b[2;32m│ \u001b[0m\u001b[33mtype\u001b[0m=\u001b[32m'tool_group'\u001b[0m,\n",
|
||||
"\u001b[2;32m│ \u001b[0m\u001b[33margs\u001b[0m=\u001b[3;35mNone\u001b[0m,\n",
|
||||
|
|
|
@ -337,9 +337,6 @@
|
|||
" provider_id: tavily-search\n",
|
||||
" provider_type: remote::tavily-search\n",
|
||||
" - config: <span style=\"font-weight: bold\">{}</span>\n",
|
||||
" provider_id: code-interpreter\n",
|
||||
" provider_type: inlin<span style=\"color: #00ff00; text-decoration-color: #00ff00; font-weight: bold\">e::c</span>ode-interpreter\n",
|
||||
" - config: <span style=\"font-weight: bold\">{}</span>\n",
|
||||
" provider_id: rag-runtime\n",
|
||||
" provider_type: inline::rag-runtime\n",
|
||||
" - config: <span style=\"font-weight: bold\">{}</span>\n",
|
||||
|
@ -378,10 +375,6 @@
|
|||
" toolgroup_id: builtin::rag\n",
|
||||
"- args: null\n",
|
||||
" mcp_endpoint: null\n",
|
||||
" provider_id: code-interpreter\n",
|
||||
" toolgroup_id: builtin::code_interpreter\n",
|
||||
"- args: null\n",
|
||||
" mcp_endpoint: null\n",
|
||||
" provider_id: wolfram-alpha\n",
|
||||
" toolgroup_id: builtin::wolfram_alpha\n",
|
||||
"vector_dbs: <span style=\"font-weight: bold\">[]</span>\n",
|
||||
|
@ -617,9 +610,6 @@
|
|||
" provider_id: tavily-search\n",
|
||||
" provider_type: remote::tavily-search\n",
|
||||
" - config: \u001b[1m{\u001b[0m\u001b[1m}\u001b[0m\n",
|
||||
" provider_id: code-interpreter\n",
|
||||
" provider_type: inlin\u001b[1;92me::c\u001b[0mode-interpreter\n",
|
||||
" - config: \u001b[1m{\u001b[0m\u001b[1m}\u001b[0m\n",
|
||||
" provider_id: rag-runtime\n",
|
||||
" provider_type: inline::rag-runtime\n",
|
||||
" - config: \u001b[1m{\u001b[0m\u001b[1m}\u001b[0m\n",
|
||||
|
@ -658,10 +648,6 @@
|
|||
" toolgroup_id: builtin::rag\n",
|
||||
"- args: null\n",
|
||||
" mcp_endpoint: null\n",
|
||||
" provider_id: code-interpreter\n",
|
||||
" toolgroup_id: builtin::code_interpreter\n",
|
||||
"- args: null\n",
|
||||
" mcp_endpoint: null\n",
|
||||
" provider_id: wolfram-alpha\n",
|
||||
" toolgroup_id: builtin::wolfram_alpha\n",
|
||||
"vector_dbs: \u001b[1m[\u001b[0m\u001b[1m]\u001b[0m\n",
|
||||
|
|
|
@ -840,7 +840,6 @@
|
|||
" \"memory_optimizations.rst\",\n",
|
||||
" \"chat.rst\",\n",
|
||||
" \"llama3.rst\",\n",
|
||||
" \"datasets.rst\",\n",
|
||||
" \"qat_finetune.rst\",\n",
|
||||
" \"lora_finetune.rst\",\n",
|
||||
"]\n",
|
||||
|
@ -1586,7 +1585,6 @@
|
|||
" \"memory_optimizations.rst\",\n",
|
||||
" \"chat.rst\",\n",
|
||||
" \"llama3.rst\",\n",
|
||||
" \"datasets.rst\",\n",
|
||||
" \"qat_finetune.rst\",\n",
|
||||
" \"lora_finetune.rst\",\n",
|
||||
"]\n",
|
||||
|
|
|
@ -44,7 +44,7 @@ def main(output_dir: str):
|
|||
if return_type_errors:
|
||||
print("\nAPI Method Return Type Validation Errors:\n")
|
||||
for error in return_type_errors:
|
||||
print(error)
|
||||
print(error, file=sys.stderr)
|
||||
sys.exit(1)
|
||||
now = str(datetime.now())
|
||||
print(
|
||||
|
|
|
@ -759,7 +759,7 @@ class Generator:
|
|||
)
|
||||
|
||||
return Operation(
|
||||
tags=[op.defining_class.__name__],
|
||||
tags=[getattr(op.defining_class, "API_NAMESPACE", op.defining_class.__name__)],
|
||||
summary=None,
|
||||
# summary=doc_string.short_description,
|
||||
description=description,
|
||||
|
@ -805,6 +805,8 @@ class Generator:
|
|||
operation_tags: List[Tag] = []
|
||||
for cls in endpoint_classes:
|
||||
doc_string = parse_type(cls)
|
||||
if hasattr(cls, "API_NAMESPACE") and cls.API_NAMESPACE != cls.__name__:
|
||||
continue
|
||||
operation_tags.append(
|
||||
Tag(
|
||||
name=cls.__name__,
|
||||
|
|
|
@ -174,14 +174,64 @@ def _validate_list_parameters_contain_data(method) -> str | None:
|
|||
return "does not have a mandatory data attribute containing the list of objects"
|
||||
|
||||
|
||||
def _validate_has_ellipsis(method) -> str | None:
|
||||
source = inspect.getsource(method)
|
||||
if "..." not in source and not "NotImplementedError" in source:
|
||||
return "does not contain ellipsis (...) in its implementation"
|
||||
|
||||
def _validate_has_return_in_docstring(method) -> str | None:
|
||||
source = inspect.getsource(method)
|
||||
return_type = method.__annotations__.get('return')
|
||||
if return_type is not None and return_type != type(None) and ":returns:" not in source:
|
||||
return "does not have a ':returns:' in its docstring"
|
||||
|
||||
def _validate_has_params_in_docstring(method) -> str | None:
|
||||
source = inspect.getsource(method)
|
||||
sig = inspect.signature(method)
|
||||
# Only check if the method has more than one parameter
|
||||
if len(sig.parameters) > 1 and ":param" not in source:
|
||||
return "does not have a ':param' in its docstring"
|
||||
|
||||
def _validate_has_no_return_none_in_docstring(method) -> str | None:
|
||||
source = inspect.getsource(method)
|
||||
return_type = method.__annotations__.get('return')
|
||||
if return_type is None and ":returns: None" in source:
|
||||
return "has a ':returns: None' in its docstring which is redundant for None-returning functions"
|
||||
|
||||
def _validate_docstring_lines_end_with_dot(method) -> str | None:
|
||||
docstring = inspect.getdoc(method)
|
||||
if docstring is None:
|
||||
return None
|
||||
|
||||
lines = docstring.split('\n')
|
||||
for line in lines:
|
||||
line = line.strip()
|
||||
if line and not any(line.endswith(char) for char in '.:{}[]()",'):
|
||||
return f"docstring line '{line}' does not end with a valid character: . : {{ }} [ ] ( ) , \""
|
||||
|
||||
_VALIDATORS = {
|
||||
"GET": [
|
||||
_validate_api_method_return_type,
|
||||
_validate_list_parameters_contain_data,
|
||||
_validate_api_method_doesnt_return_list,
|
||||
_validate_has_ellipsis,
|
||||
_validate_has_return_in_docstring,
|
||||
_validate_has_params_in_docstring,
|
||||
_validate_docstring_lines_end_with_dot,
|
||||
],
|
||||
"DELETE": [
|
||||
_validate_api_delete_method_returns_none,
|
||||
_validate_has_ellipsis,
|
||||
_validate_has_return_in_docstring,
|
||||
_validate_has_params_in_docstring,
|
||||
_validate_has_no_return_none_in_docstring
|
||||
],
|
||||
"POST": [
|
||||
_validate_has_ellipsis,
|
||||
_validate_has_return_in_docstring,
|
||||
_validate_has_params_in_docstring,
|
||||
_validate_has_no_return_none_in_docstring,
|
||||
_validate_docstring_lines_end_with_dot,
|
||||
],
|
||||
}
|
||||
|
||||
|
|
|
@ -51,6 +51,7 @@ chunks = [
|
|||
"mime_type": "text/plain",
|
||||
"metadata": {
|
||||
"document_id": "doc1",
|
||||
"author": "Jane Doe",
|
||||
},
|
||||
},
|
||||
]
|
||||
|
@ -98,6 +99,17 @@ results = client.tool_runtime.rag_tool.query(
|
|||
)
|
||||
```
|
||||
|
||||
You can configure how the RAG tool adds metadata to the context if you find it useful for your application. Simply add:
|
||||
```python
|
||||
# Query documents
|
||||
results = client.tool_runtime.rag_tool.query(
|
||||
vector_db_ids=[vector_db_id],
|
||||
content="What do you know about...",
|
||||
query_config={
|
||||
"chunk_template": "Result {index}\nContent: {chunk.content}\nMetadata: {metadata}\n",
|
||||
},
|
||||
)
|
||||
```
|
||||
### Building RAG-Enhanced Agents
|
||||
|
||||
One of the most powerful patterns is combining agents with RAG capabilities. Here's a complete example:
|
||||
|
@ -115,6 +127,12 @@ agent = Agent(
|
|||
"name": "builtin::rag/knowledge_search",
|
||||
"args": {
|
||||
"vector_db_ids": [vector_db_id],
|
||||
# Defaults
|
||||
"query_config": {
|
||||
"chunk_size_in_tokens": 512,
|
||||
"chunk_overlap_in_tokens": 0,
|
||||
"chunk_template": "Result {index}\nContent: {chunk.content}\nMetadata: {metadata}\n",
|
||||
},
|
||||
},
|
||||
}
|
||||
],
|
||||
|
|
|
@ -165,34 +165,6 @@ all_tools = client.tools.list_tools()
|
|||
group_tools = client.tools.list_tools(toolgroup_id="search_tools")
|
||||
```
|
||||
|
||||
## Simple Example: Using an Agent with the Code-Interpreter Tool
|
||||
|
||||
```python
|
||||
from llama_stack_client import Agent
|
||||
|
||||
# Instantiate the AI agent with the given configuration
|
||||
agent = Agent(
|
||||
client,
|
||||
name="code-interpreter",
|
||||
description="A code interpreter agent for executing Python code snippets",
|
||||
instructions="""
|
||||
You are a highly reliable, concise, and precise assistant.
|
||||
Always show the generated code, never generate your own code, and never anticipate results.
|
||||
""",
|
||||
model="meta-llama/Llama-3.2-3B-Instruct",
|
||||
tools=["builtin::code_interpreter"],
|
||||
max_infer_iters=5,
|
||||
)
|
||||
|
||||
# Start a session
|
||||
session_id = agent.create_session("tool_session")
|
||||
|
||||
# Send a query to the AI agent for code execution
|
||||
response = agent.create_turn(
|
||||
messages=[{"role": "user", "content": "Run this code: print(3 ** 4 - 5 * 2)"}],
|
||||
session_id=session_id,
|
||||
)
|
||||
```
|
||||
## Simple Example 2: Using an Agent with the Web Search Tool
|
||||
1. Start by registering a Tavily API key at [Tavily](https://tavily.com/).
|
||||
2. [Optional] Provide the API key directly to the Llama Stack server
|
||||
|
|
|
@ -110,6 +110,8 @@ html_theme_options = {
|
|||
"canonical_url": "https://github.com/meta-llama/llama-stack",
|
||||
"collapse_navigation": False,
|
||||
# "style_nav_header_background": "#c3c9d4",
|
||||
'display_version': True,
|
||||
'version_selector': True,
|
||||
}
|
||||
|
||||
default_dark_mode = False
|
||||
|
|
|
@ -6,7 +6,7 @@ This guide will walk you through the process of adding a new API provider to Lla
|
|||
- Begin by reviewing the [core concepts](../concepts/index.md) of Llama Stack and choose the API your provider belongs to (Inference, Safety, VectorIO, etc.)
|
||||
- Determine the provider type ({repopath}`Remote::llama_stack/providers/remote` or {repopath}`Inline::llama_stack/providers/inline`). Remote providers make requests to external services, while inline providers execute implementation locally.
|
||||
- Add your provider to the appropriate {repopath}`Registry::llama_stack/providers/registry/`. Specify pip dependencies necessary.
|
||||
- Update any distribution {repopath}`Templates::llama_stack/templates/` build.yaml and run.yaml files if they should include your provider by default. Run {repopath}`./scripts/distro_codegen.py` if necessary. Note that `distro_codegen.py` will fail if the new provider causes any distribution template to attempt to import provider-specific dependencies. This usually means the distribution's `get_distribution_template()` code path should only import any necessary Config or model alias definitions from each provider and not the provider's actual implementation.
|
||||
- Update any distribution {repopath}`Templates::llama_stack/templates/` `build.yaml` and `run.yaml` files if they should include your provider by default. Run {repopath}`./scripts/distro_codegen.py` if necessary. Note that `distro_codegen.py` will fail if the new provider causes any distribution template to attempt to import provider-specific dependencies. This usually means the distribution's `get_distribution_template()` code path should only import any necessary Config or model alias definitions from each provider and not the provider's actual implementation.
|
||||
|
||||
|
||||
Here are some example PRs to help you get started:
|
||||
|
@ -33,6 +33,7 @@ Note that each provider's `sample_run_config()` method (in the configuration cla
|
|||
|
||||
Unit tests are located in {repopath}`tests/unit`. Provider-specific unit tests are located in {repopath}`tests/unit/providers`. These tests are all run automatically as part of the CI process.
|
||||
|
||||
Consult {repopath}`tests/unit/README.md` for more details on how to run the tests manually.
|
||||
|
||||
### 3. Additional end-to-end testing
|
||||
|
||||
|
|
|
@ -178,7 +178,7 @@ image_name: ollama
|
|||
image_type: conda
|
||||
|
||||
# If some providers are external, you can specify the path to the implementation
|
||||
external_providers_dir: /etc/llama-stack/providers.d
|
||||
external_providers_dir: ~/.llama/providers.d
|
||||
```
|
||||
|
||||
```
|
||||
|
@ -206,7 +206,7 @@ distribution_spec:
|
|||
image_type: container
|
||||
image_name: ci-test
|
||||
# Path to external provider implementations
|
||||
external_providers_dir: /etc/llama-stack/providers.d
|
||||
external_providers_dir: ~/.llama/providers.d
|
||||
```
|
||||
|
||||
Here's an example for a custom Ollama provider:
|
||||
|
@ -271,7 +271,7 @@ Now, let's start the Llama Stack Distribution Server. You will need the YAML con
|
|||
|
||||
```
|
||||
llama stack run -h
|
||||
usage: llama stack run [-h] [--port PORT] [--image-name IMAGE_NAME] [--disable-ipv6] [--env KEY=VALUE] [--tls-keyfile TLS_KEYFILE] [--tls-certfile TLS_CERTFILE]
|
||||
usage: llama stack run [-h] [--port PORT] [--image-name IMAGE_NAME] [--env KEY=VALUE] [--tls-keyfile TLS_KEYFILE] [--tls-certfile TLS_CERTFILE]
|
||||
[--image-type {conda,container,venv}]
|
||||
config
|
||||
|
||||
|
@ -285,7 +285,6 @@ options:
|
|||
--port PORT Port to run the server on. It can also be passed via the env var LLAMA_STACK_PORT. (default: 8321)
|
||||
--image-name IMAGE_NAME
|
||||
Name of the image to run. Defaults to the current environment (default: None)
|
||||
--disable-ipv6 Disable IPv6 support (default: False)
|
||||
--env KEY=VALUE Environment variables to pass to the server in KEY=VALUE format. Can be specified multiple times. (default: [])
|
||||
--tls-keyfile TLS_KEYFILE
|
||||
Path to TLS key file for HTTPS (default: None)
|
||||
|
|
|
@ -172,7 +172,7 @@ spec:
|
|||
- name: llama-stack
|
||||
image: localhost/llama-stack-run-k8s:latest
|
||||
imagePullPolicy: IfNotPresent
|
||||
command: ["python", "-m", "llama_stack.distribution.server.server", "--yaml-config", "/app/config.yaml"]
|
||||
command: ["python", "-m", "llama_stack.distribution.server.server", "--config", "/app/config.yaml"]
|
||||
ports:
|
||||
- containerPort: 5000
|
||||
volumeMounts:
|
||||
|
|
|
@ -18,7 +18,7 @@ The `llamastack/distribution-watsonx` distribution consists of the following pro
|
|||
| agents | `inline::meta-reference` |
|
||||
| datasetio | `remote::huggingface`, `inline::localfs` |
|
||||
| eval | `inline::meta-reference` |
|
||||
| inference | `remote::watsonx` |
|
||||
| inference | `remote::watsonx`, `inline::sentence-transformers` |
|
||||
| safety | `inline::llama-guard` |
|
||||
| scoring | `inline::basic`, `inline::llm-as-judge`, `inline::braintrust` |
|
||||
| telemetry | `inline::meta-reference` |
|
||||
|
@ -70,7 +70,7 @@ docker run \
|
|||
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
|
||||
-v ./run.yaml:/root/my-run.yaml \
|
||||
llamastack/distribution-watsonx \
|
||||
--yaml-config /root/my-run.yaml \
|
||||
--config /root/my-run.yaml \
|
||||
--port $LLAMA_STACK_PORT \
|
||||
--env WATSONX_API_KEY=$WATSONX_API_KEY \
|
||||
--env WATSONX_PROJECT_ID=$WATSONX_PROJECT_ID \
|
||||
|
|
|
@ -52,7 +52,7 @@ docker run \
|
|||
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
|
||||
-v ./run.yaml:/root/my-run.yaml \
|
||||
llamastack/distribution-cerebras \
|
||||
--yaml-config /root/my-run.yaml \
|
||||
--config /root/my-run.yaml \
|
||||
--port $LLAMA_STACK_PORT \
|
||||
--env CEREBRAS_API_KEY=$CEREBRAS_API_KEY
|
||||
```
|
||||
|
|
|
@ -23,7 +23,7 @@ The `llamastack/distribution-dell` distribution consists of the following provid
|
|||
| safety | `inline::llama-guard` |
|
||||
| scoring | `inline::basic`, `inline::llm-as-judge`, `inline::braintrust` |
|
||||
| telemetry | `inline::meta-reference` |
|
||||
| tool_runtime | `remote::brave-search`, `remote::tavily-search`, `inline::code-interpreter`, `inline::rag-runtime` |
|
||||
| tool_runtime | `remote::brave-search`, `remote::tavily-search`, `inline::rag-runtime` |
|
||||
| vector_io | `inline::faiss`, `remote::chromadb`, `remote::pgvector` |
|
||||
|
||||
|
||||
|
@ -155,7 +155,7 @@ docker run \
|
|||
-v $HOME/.llama:/root/.llama \
|
||||
-v ./llama_stack/templates/tgi/run-with-safety.yaml:/root/my-run.yaml \
|
||||
llamastack/distribution-dell \
|
||||
--yaml-config /root/my-run.yaml \
|
||||
--config /root/my-run.yaml \
|
||||
--port $LLAMA_STACK_PORT \
|
||||
--env INFERENCE_MODEL=$INFERENCE_MODEL \
|
||||
--env DEH_URL=$DEH_URL \
|
||||
|
|
|
@ -144,7 +144,7 @@ docker run \
|
|||
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
|
||||
-v ./run.yaml:/root/my-run.yaml \
|
||||
llamastack/distribution-nvidia \
|
||||
--yaml-config /root/my-run.yaml \
|
||||
--config /root/my-run.yaml \
|
||||
--port $LLAMA_STACK_PORT \
|
||||
--env NVIDIA_API_KEY=$NVIDIA_API_KEY
|
||||
```
|
||||
|
|
|
@ -19,6 +19,7 @@ The `llamastack/distribution-ollama` distribution consists of the following prov
|
|||
| datasetio | `remote::huggingface`, `inline::localfs` |
|
||||
| eval | `inline::meta-reference` |
|
||||
| inference | `remote::ollama` |
|
||||
| post_training | `inline::huggingface` |
|
||||
| safety | `inline::llama-guard` |
|
||||
| scoring | `inline::basic`, `inline::llm-as-judge`, `inline::braintrust` |
|
||||
| telemetry | `inline::meta-reference` |
|
||||
|
@ -97,7 +98,7 @@ docker run \
|
|||
-v ~/.llama:/root/.llama \
|
||||
-v ./llama_stack/templates/ollama/run-with-safety.yaml:/root/my-run.yaml \
|
||||
llamastack/distribution-ollama \
|
||||
--yaml-config /root/my-run.yaml \
|
||||
--config /root/my-run.yaml \
|
||||
--port $LLAMA_STACK_PORT \
|
||||
--env INFERENCE_MODEL=$INFERENCE_MODEL \
|
||||
--env SAFETY_MODEL=$SAFETY_MODEL \
|
||||
|
|
|
@ -233,7 +233,7 @@ docker run \
|
|||
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
|
||||
-v ./llama_stack/templates/remote-vllm/run.yaml:/root/my-run.yaml \
|
||||
llamastack/distribution-remote-vllm \
|
||||
--yaml-config /root/my-run.yaml \
|
||||
--config /root/my-run.yaml \
|
||||
--port $LLAMA_STACK_PORT \
|
||||
--env INFERENCE_MODEL=$INFERENCE_MODEL \
|
||||
--env VLLM_URL=http://host.docker.internal:$INFERENCE_PORT/v1
|
||||
|
@ -255,7 +255,7 @@ docker run \
|
|||
-v ~/.llama:/root/.llama \
|
||||
-v ./llama_stack/templates/remote-vllm/run-with-safety.yaml:/root/my-run.yaml \
|
||||
llamastack/distribution-remote-vllm \
|
||||
--yaml-config /root/my-run.yaml \
|
||||
--config /root/my-run.yaml \
|
||||
--port $LLAMA_STACK_PORT \
|
||||
--env INFERENCE_MODEL=$INFERENCE_MODEL \
|
||||
--env VLLM_URL=http://host.docker.internal:$INFERENCE_PORT/v1 \
|
||||
|
|
|
@ -16,10 +16,10 @@ The `llamastack/distribution-sambanova` distribution consists of the following p
|
|||
| API | Provider(s) |
|
||||
|-----|-------------|
|
||||
| agents | `inline::meta-reference` |
|
||||
| inference | `remote::sambanova` |
|
||||
| inference | `remote::sambanova`, `inline::sentence-transformers` |
|
||||
| safety | `inline::llama-guard` |
|
||||
| telemetry | `inline::meta-reference` |
|
||||
| tool_runtime | `remote::brave-search`, `remote::tavily-search`, `inline::rag-runtime` |
|
||||
| tool_runtime | `remote::brave-search`, `remote::tavily-search`, `inline::rag-runtime`, `remote::model-context-protocol`, `remote::wolfram-alpha` |
|
||||
| vector_io | `inline::faiss`, `remote::chromadb`, `remote::pgvector` |
|
||||
|
||||
|
||||
|
@ -28,22 +28,22 @@ The `llamastack/distribution-sambanova` distribution consists of the following p
|
|||
The following environment variables can be configured:
|
||||
|
||||
- `LLAMASTACK_PORT`: Port for the Llama Stack distribution server (default: `8321`)
|
||||
- `SAMBANOVA_API_KEY`: SambaNova.AI API Key (default: ``)
|
||||
- `SAMBANOVA_API_KEY`: SambaNova API Key (default: ``)
|
||||
|
||||
### Models
|
||||
|
||||
The following models are available by default:
|
||||
|
||||
- `Meta-Llama-3.1-8B-Instruct (aliases: meta-llama/Llama-3.1-8B-Instruct)`
|
||||
- `Meta-Llama-3.1-70B-Instruct (aliases: meta-llama/Llama-3.1-70B-Instruct)`
|
||||
- `Meta-Llama-3.1-405B-Instruct (aliases: meta-llama/Llama-3.1-405B-Instruct-FP8)`
|
||||
- `Meta-Llama-3.2-1B-Instruct (aliases: meta-llama/Llama-3.2-1B-Instruct)`
|
||||
- `Meta-Llama-3.2-3B-Instruct (aliases: meta-llama/Llama-3.2-3B-Instruct)`
|
||||
- `Meta-Llama-3.3-70B-Instruct (aliases: meta-llama/Llama-3.3-70B-Instruct)`
|
||||
- `Llama-3.2-11B-Vision-Instruct (aliases: meta-llama/Llama-3.2-11B-Vision-Instruct)`
|
||||
- `Llama-3.2-90B-Vision-Instruct (aliases: meta-llama/Llama-3.2-90B-Vision-Instruct)`
|
||||
- `Meta-Llama-Guard-3-8B (aliases: meta-llama/Llama-Guard-3-8B)`
|
||||
- `Llama-4-Scout-17B-16E-Instruct (aliases: meta-llama/Llama-4-Scout-17B-16E-Instruct)`
|
||||
- `sambanova/Meta-Llama-3.1-8B-Instruct (aliases: meta-llama/Llama-3.1-8B-Instruct)`
|
||||
- `sambanova/Meta-Llama-3.1-405B-Instruct (aliases: meta-llama/Llama-3.1-405B-Instruct-FP8)`
|
||||
- `sambanova/Meta-Llama-3.2-1B-Instruct (aliases: meta-llama/Llama-3.2-1B-Instruct)`
|
||||
- `sambanova/Meta-Llama-3.2-3B-Instruct (aliases: meta-llama/Llama-3.2-3B-Instruct)`
|
||||
- `sambanova/Meta-Llama-3.3-70B-Instruct (aliases: meta-llama/Llama-3.3-70B-Instruct)`
|
||||
- `sambanova/Llama-3.2-11B-Vision-Instruct (aliases: meta-llama/Llama-3.2-11B-Vision-Instruct)`
|
||||
- `sambanova/Llama-3.2-90B-Vision-Instruct (aliases: meta-llama/Llama-3.2-90B-Vision-Instruct)`
|
||||
- `sambanova/Llama-4-Scout-17B-16E-Instruct (aliases: meta-llama/Llama-4-Scout-17B-16E-Instruct)`
|
||||
- `sambanova/Llama-4-Maverick-17B-128E-Instruct (aliases: meta-llama/Llama-4-Maverick-17B-128E-Instruct)`
|
||||
- `sambanova/Meta-Llama-Guard-3-8B (aliases: meta-llama/Llama-Guard-3-8B)`
|
||||
|
||||
|
||||
### Prerequisite: API Keys
|
||||
|
|
|
@ -117,7 +117,7 @@ docker run \
|
|||
-v ~/.llama:/root/.llama \
|
||||
-v ./llama_stack/templates/tgi/run-with-safety.yaml:/root/my-run.yaml \
|
||||
llamastack/distribution-tgi \
|
||||
--yaml-config /root/my-run.yaml \
|
||||
--config /root/my-run.yaml \
|
||||
--port $LLAMA_STACK_PORT \
|
||||
--env INFERENCE_MODEL=$INFERENCE_MODEL \
|
||||
--env TGI_URL=http://host.docker.internal:$INFERENCE_PORT \
|
||||
|
|
|
@ -42,7 +42,7 @@ powershell -ExecutionPolicy ByPass -c "irm https://astral.sh/uv/install.ps1 | ie
|
|||
Setup your virtual environment.
|
||||
|
||||
```bash
|
||||
uv venv --python 3.10
|
||||
uv sync --python 3.10
|
||||
source .venv/bin/activate
|
||||
```
|
||||
## Step 2: Run Llama Stack
|
||||
|
@ -445,7 +445,6 @@ from llama_stack_client import LlamaStackClient
|
|||
from llama_stack_client import Agent, AgentEventLogger
|
||||
from llama_stack_client.types import Document
|
||||
import uuid
|
||||
from termcolor import cprint
|
||||
|
||||
client = LlamaStackClient(base_url="http://localhost:8321")
|
||||
|
||||
|
@ -463,7 +462,6 @@ urls = [
|
|||
"memory_optimizations.rst",
|
||||
"chat.rst",
|
||||
"llama3.rst",
|
||||
"datasets.rst",
|
||||
"qat_finetune.rst",
|
||||
"lora_finetune.rst",
|
||||
]
|
||||
|
|
|
@ -10,7 +10,7 @@ Llama Stack supports external providers that live outside of the main codebase.
|
|||
To enable external providers, you need to configure the `external_providers_dir` in your Llama Stack configuration. This directory should contain your external provider specifications:
|
||||
|
||||
```yaml
|
||||
external_providers_dir: /etc/llama-stack/providers.d/
|
||||
external_providers_dir: ~/.llama/providers.d/
|
||||
```
|
||||
|
||||
## Directory Structure
|
||||
|
@ -53,7 +53,7 @@ Here's a list of known external providers that you can use with Llama Stack:
|
|||
| Name | Description | API | Type | Repository |
|
||||
|------|-------------|-----|------|------------|
|
||||
| KubeFlow Training | Train models with KubeFlow | Post Training | Remote | [llama-stack-provider-kft](https://github.com/opendatahub-io/llama-stack-provider-kft) |
|
||||
| KubeFlow Pipelines | Train models with KubeFlow Pipelines | Post Training | Remote | [llama-stack-provider-kfp-trainer](https://github.com/opendatahub-io/llama-stack-provider-kfp-trainer) |
|
||||
| KubeFlow Pipelines | Train models with KubeFlow Pipelines | Post Training | Inline **and** Remote | [llama-stack-provider-kfp-trainer](https://github.com/opendatahub-io/llama-stack-provider-kfp-trainer) |
|
||||
| RamaLama | Inference models with RamaLama | Inference | Remote | [ramalama-stack](https://github.com/containers/ramalama-stack) |
|
||||
| TrustyAI LM-Eval | Evaluate models with TrustyAI LM-Eval | Eval | Remote | [llama-stack-provider-lmeval](https://github.com/trustyai-explainability/llama-stack-provider-lmeval) |
|
||||
|
||||
|
@ -182,7 +182,7 @@ dependencies = ["llama-stack", "pydantic", "ollama", "aiohttp"]
|
|||
3. Create the provider specification:
|
||||
|
||||
```yaml
|
||||
# /etc/llama-stack/providers.d/remote/inference/custom_ollama.yaml
|
||||
# ~/.llama/providers.d/remote/inference/custom_ollama.yaml
|
||||
adapter:
|
||||
adapter_type: custom_ollama
|
||||
pip_packages: ["ollama", "aiohttp"]
|
||||
|
@ -201,7 +201,7 @@ uv pip install -e .
|
|||
5. Configure Llama Stack to use external providers:
|
||||
|
||||
```yaml
|
||||
external_providers_dir: /etc/llama-stack/providers.d/
|
||||
external_providers_dir: ~/.llama/providers.d/
|
||||
```
|
||||
|
||||
The provider will now be available in Llama Stack with the type `remote::custom_ollama`.
|
||||
|
|
|
@ -253,8 +253,6 @@ llama-stack-client toolgroups list
|
|||
+---------------------------+------------------+------+---------------+
|
||||
| identifier | provider_id | args | mcp_endpoint |
|
||||
+===========================+==================+======+===============+
|
||||
| builtin::code_interpreter | code-interpreter | None | None |
|
||||
+---------------------------+------------------+------+---------------+
|
||||
| builtin::rag | rag-runtime | None | None |
|
||||
+---------------------------+------------------+------+---------------+
|
||||
| builtin::websearch | tavily-search | None | None |
|
||||
|
|
61
install.sh
61
install.sh
|
@ -38,6 +38,67 @@ wait_for_service() {
|
|||
return 0
|
||||
}
|
||||
|
||||
usage() {
|
||||
cat << EOF
|
||||
📚 Llama-Stack Deployment Script
|
||||
|
||||
Description:
|
||||
This script sets up and deploys Llama-Stack with Ollama integration in containers.
|
||||
It handles both Docker and Podman runtimes and includes automatic platform detection.
|
||||
|
||||
Usage:
|
||||
$(basename "$0") [OPTIONS]
|
||||
|
||||
Options:
|
||||
-p, --port PORT Server port for Llama-Stack (default: ${PORT})
|
||||
-o, --ollama-port PORT Ollama service port (default: ${OLLAMA_PORT})
|
||||
-m, --model MODEL Model alias to use (default: ${MODEL_ALIAS})
|
||||
-i, --image IMAGE Server image (default: ${SERVER_IMAGE})
|
||||
-t, --timeout SECONDS Service wait timeout in seconds (default: ${WAIT_TIMEOUT})
|
||||
-h, --help Show this help message
|
||||
|
||||
For more information:
|
||||
Documentation: https://llama-stack.readthedocs.io/
|
||||
GitHub: https://github.com/meta-llama/llama-stack
|
||||
|
||||
Report issues:
|
||||
https://github.com/meta-llama/llama-stack/issues
|
||||
EOF
|
||||
}
|
||||
|
||||
# Parse command line arguments
|
||||
while [[ $# -gt 0 ]]; do
|
||||
case $1 in
|
||||
-h|--help)
|
||||
usage
|
||||
exit 0
|
||||
;;
|
||||
-p|--port)
|
||||
PORT="$2"
|
||||
shift 2
|
||||
;;
|
||||
-o|--ollama-port)
|
||||
OLLAMA_PORT="$2"
|
||||
shift 2
|
||||
;;
|
||||
-m|--model)
|
||||
MODEL_ALIAS="$2"
|
||||
shift 2
|
||||
;;
|
||||
-i|--image)
|
||||
SERVER_IMAGE="$2"
|
||||
shift 2
|
||||
;;
|
||||
-t|--timeout)
|
||||
WAIT_TIMEOUT="$2"
|
||||
shift 2
|
||||
;;
|
||||
*)
|
||||
die "Unknown option: $1"
|
||||
;;
|
||||
esac
|
||||
done
|
||||
|
||||
if command -v docker &> /dev/null; then
|
||||
ENGINE="docker"
|
||||
elif command -v podman &> /dev/null; then
|
||||
|
|
|
@ -4,6 +4,7 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import sys
|
||||
from collections.abc import AsyncIterator
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
|
@ -12,6 +13,7 @@ from typing import Annotated, Any, Literal, Protocol, runtime_checkable
|
|||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
from llama_stack.apis.common.content_types import URL, ContentDelta, InterleavedContent
|
||||
from llama_stack.apis.common.responses import PaginatedResponse
|
||||
from llama_stack.apis.inference import (
|
||||
CompletionMessage,
|
||||
ResponseFormat,
|
||||
|
@ -29,12 +31,20 @@ from llama_stack.apis.tools import ToolDef
|
|||
from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
|
||||
|
||||
from .openai_responses import (
|
||||
OpenAIResponseInputMessage,
|
||||
OpenAIResponseInput,
|
||||
OpenAIResponseInputTool,
|
||||
OpenAIResponseObject,
|
||||
OpenAIResponseObjectStream,
|
||||
)
|
||||
|
||||
# TODO: use enum.StrEnum when we drop support for python 3.10
|
||||
if sys.version_info >= (3, 11):
|
||||
from enum import StrEnum
|
||||
else:
|
||||
|
||||
class StrEnum(str, Enum):
|
||||
"""Backport of StrEnum for Python 3.10 and below."""
|
||||
|
||||
|
||||
class Attachment(BaseModel):
|
||||
"""An attachment to an agent turn.
|
||||
|
@ -73,7 +83,7 @@ class StepCommon(BaseModel):
|
|||
completed_at: datetime | None = None
|
||||
|
||||
|
||||
class StepType(Enum):
|
||||
class StepType(StrEnum):
|
||||
"""Type of the step in an agent turn.
|
||||
|
||||
:cvar inference: The step is an inference step that calls an LLM.
|
||||
|
@ -97,7 +107,7 @@ class InferenceStep(StepCommon):
|
|||
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
|
||||
step_type: Literal[StepType.inference.value] = StepType.inference.value
|
||||
step_type: Literal[StepType.inference] = StepType.inference
|
||||
model_response: CompletionMessage
|
||||
|
||||
|
||||
|
@ -109,7 +119,7 @@ class ToolExecutionStep(StepCommon):
|
|||
:param tool_responses: The tool responses from the tool calls.
|
||||
"""
|
||||
|
||||
step_type: Literal[StepType.tool_execution.value] = StepType.tool_execution.value
|
||||
step_type: Literal[StepType.tool_execution] = StepType.tool_execution
|
||||
tool_calls: list[ToolCall]
|
||||
tool_responses: list[ToolResponse]
|
||||
|
||||
|
@ -121,7 +131,7 @@ class ShieldCallStep(StepCommon):
|
|||
:param violation: The violation from the shield call.
|
||||
"""
|
||||
|
||||
step_type: Literal[StepType.shield_call.value] = StepType.shield_call.value
|
||||
step_type: Literal[StepType.shield_call] = StepType.shield_call
|
||||
violation: SafetyViolation | None
|
||||
|
||||
|
||||
|
@ -133,7 +143,7 @@ class MemoryRetrievalStep(StepCommon):
|
|||
:param inserted_context: The context retrieved from the vector databases.
|
||||
"""
|
||||
|
||||
step_type: Literal[StepType.memory_retrieval.value] = StepType.memory_retrieval.value
|
||||
step_type: Literal[StepType.memory_retrieval] = StepType.memory_retrieval
|
||||
# TODO: should this be List[str]?
|
||||
vector_db_ids: str
|
||||
inserted_context: InterleavedContent
|
||||
|
@ -154,7 +164,7 @@ class Turn(BaseModel):
|
|||
input_messages: list[UserMessage | ToolResponseMessage]
|
||||
steps: list[Step]
|
||||
output_message: CompletionMessage
|
||||
output_attachments: list[Attachment] | None = Field(default_factory=list)
|
||||
output_attachments: list[Attachment] | None = Field(default_factory=lambda: [])
|
||||
|
||||
started_at: datetime
|
||||
completed_at: datetime | None = None
|
||||
|
@ -182,10 +192,10 @@ register_schema(AgentToolGroup, name="AgentTool")
|
|||
class AgentConfigCommon(BaseModel):
|
||||
sampling_params: SamplingParams | None = Field(default_factory=SamplingParams)
|
||||
|
||||
input_shields: list[str] | None = Field(default_factory=list)
|
||||
output_shields: list[str] | None = Field(default_factory=list)
|
||||
toolgroups: list[AgentToolGroup] | None = Field(default_factory=list)
|
||||
client_tools: list[ToolDef] | None = Field(default_factory=list)
|
||||
input_shields: list[str] | None = Field(default_factory=lambda: [])
|
||||
output_shields: list[str] | None = Field(default_factory=lambda: [])
|
||||
toolgroups: list[AgentToolGroup] | None = Field(default_factory=lambda: [])
|
||||
client_tools: list[ToolDef] | None = Field(default_factory=lambda: [])
|
||||
tool_choice: ToolChoice | None = Field(default=None, deprecated="use tool_config instead")
|
||||
tool_prompt_format: ToolPromptFormat | None = Field(default=None, deprecated="use tool_config instead")
|
||||
tool_config: ToolConfig | None = Field(default=None)
|
||||
|
@ -232,21 +242,11 @@ class Agent(BaseModel):
|
|||
created_at: datetime
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ListAgentsResponse(BaseModel):
|
||||
data: list[Agent]
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ListAgentSessionsResponse(BaseModel):
|
||||
data: list[Session]
|
||||
|
||||
|
||||
class AgentConfigOverridablePerTurn(AgentConfigCommon):
|
||||
instructions: str | None = None
|
||||
|
||||
|
||||
class AgentTurnResponseEventType(Enum):
|
||||
class AgentTurnResponseEventType(StrEnum):
|
||||
step_start = "step_start"
|
||||
step_complete = "step_complete"
|
||||
step_progress = "step_progress"
|
||||
|
@ -258,15 +258,15 @@ class AgentTurnResponseEventType(Enum):
|
|||
|
||||
@json_schema_type
|
||||
class AgentTurnResponseStepStartPayload(BaseModel):
|
||||
event_type: Literal[AgentTurnResponseEventType.step_start.value] = AgentTurnResponseEventType.step_start.value
|
||||
event_type: Literal[AgentTurnResponseEventType.step_start] = AgentTurnResponseEventType.step_start
|
||||
step_type: StepType
|
||||
step_id: str
|
||||
metadata: dict[str, Any] | None = Field(default_factory=dict)
|
||||
metadata: dict[str, Any] | None = Field(default_factory=lambda: {})
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class AgentTurnResponseStepCompletePayload(BaseModel):
|
||||
event_type: Literal[AgentTurnResponseEventType.step_complete.value] = AgentTurnResponseEventType.step_complete.value
|
||||
event_type: Literal[AgentTurnResponseEventType.step_complete] = AgentTurnResponseEventType.step_complete
|
||||
step_type: StepType
|
||||
step_id: str
|
||||
step_details: Step
|
||||
|
@ -276,7 +276,7 @@ class AgentTurnResponseStepCompletePayload(BaseModel):
|
|||
class AgentTurnResponseStepProgressPayload(BaseModel):
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
|
||||
event_type: Literal[AgentTurnResponseEventType.step_progress.value] = AgentTurnResponseEventType.step_progress.value
|
||||
event_type: Literal[AgentTurnResponseEventType.step_progress] = AgentTurnResponseEventType.step_progress
|
||||
step_type: StepType
|
||||
step_id: str
|
||||
|
||||
|
@ -285,21 +285,19 @@ class AgentTurnResponseStepProgressPayload(BaseModel):
|
|||
|
||||
@json_schema_type
|
||||
class AgentTurnResponseTurnStartPayload(BaseModel):
|
||||
event_type: Literal[AgentTurnResponseEventType.turn_start.value] = AgentTurnResponseEventType.turn_start.value
|
||||
event_type: Literal[AgentTurnResponseEventType.turn_start] = AgentTurnResponseEventType.turn_start
|
||||
turn_id: str
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class AgentTurnResponseTurnCompletePayload(BaseModel):
|
||||
event_type: Literal[AgentTurnResponseEventType.turn_complete.value] = AgentTurnResponseEventType.turn_complete.value
|
||||
event_type: Literal[AgentTurnResponseEventType.turn_complete] = AgentTurnResponseEventType.turn_complete
|
||||
turn: Turn
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class AgentTurnResponseTurnAwaitingInputPayload(BaseModel):
|
||||
event_type: Literal[AgentTurnResponseEventType.turn_awaiting_input.value] = (
|
||||
AgentTurnResponseEventType.turn_awaiting_input.value
|
||||
)
|
||||
event_type: Literal[AgentTurnResponseEventType.turn_awaiting_input] = AgentTurnResponseEventType.turn_awaiting_input
|
||||
turn: Turn
|
||||
|
||||
|
||||
|
@ -341,7 +339,7 @@ class AgentTurnCreateRequest(AgentConfigOverridablePerTurn):
|
|||
messages: list[UserMessage | ToolResponseMessage]
|
||||
|
||||
documents: list[Document] | None = None
|
||||
toolgroups: list[AgentToolGroup] | None = None
|
||||
toolgroups: list[AgentToolGroup] | None = Field(default_factory=lambda: [])
|
||||
|
||||
stream: bool | None = False
|
||||
tool_config: ToolConfig | None = None
|
||||
|
@ -415,8 +413,9 @@ class Agents(Protocol):
|
|||
:param toolgroups: (Optional) List of toolgroups to create the turn with, will be used in addition to the agent's config toolgroups for the request.
|
||||
:param tool_config: (Optional) The tool configuration to create the turn with, will be used to override the agent's tool_config.
|
||||
:returns: If stream=False, returns a Turn object.
|
||||
If stream=True, returns an SSE event stream of AgentTurnResponseStreamChunk
|
||||
If stream=True, returns an SSE event stream of AgentTurnResponseStreamChunk.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(
|
||||
route="/agents/{agent_id}/session/{session_id}/turn/{turn_id}/resume",
|
||||
|
@ -510,6 +509,7 @@ class Agents(Protocol):
|
|||
:param session_id: The ID of the session to get.
|
||||
:param agent_id: The ID of the agent to get the session for.
|
||||
:param turn_ids: (Optional) List of turn IDs to filter the session by.
|
||||
:returns: A Session.
|
||||
"""
|
||||
...
|
||||
|
||||
|
@ -519,7 +519,7 @@ class Agents(Protocol):
|
|||
session_id: str,
|
||||
agent_id: str,
|
||||
) -> None:
|
||||
"""Delete an agent session by its ID.
|
||||
"""Delete an agent session by its ID and its associated turns.
|
||||
|
||||
:param session_id: The ID of the session to delete.
|
||||
:param agent_id: The ID of the agent to delete the session for.
|
||||
|
@ -531,17 +531,19 @@ class Agents(Protocol):
|
|||
self,
|
||||
agent_id: str,
|
||||
) -> None:
|
||||
"""Delete an agent by its ID.
|
||||
"""Delete an agent by its ID and its associated sessions and turns.
|
||||
|
||||
:param agent_id: The ID of the agent to delete.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/agents", method="GET")
|
||||
async def list_agents(self) -> ListAgentsResponse:
|
||||
async def list_agents(self, start_index: int | None = None, limit: int | None = None) -> PaginatedResponse:
|
||||
"""List all agents.
|
||||
|
||||
:returns: A ListAgentsResponse.
|
||||
:param start_index: The index to start the pagination from.
|
||||
:param limit: The number of agents to return.
|
||||
:returns: A PaginatedResponse.
|
||||
"""
|
||||
...
|
||||
|
||||
|
@ -558,11 +560,15 @@ class Agents(Protocol):
|
|||
async def list_agent_sessions(
|
||||
self,
|
||||
agent_id: str,
|
||||
) -> ListAgentSessionsResponse:
|
||||
start_index: int | None = None,
|
||||
limit: int | None = None,
|
||||
) -> PaginatedResponse:
|
||||
"""List all session(s) of a given agent.
|
||||
|
||||
:param agent_id: The ID of the agent to list sessions for.
|
||||
:returns: A ListAgentSessionsResponse.
|
||||
:param start_index: The index to start the pagination from.
|
||||
:param limit: The number of sessions to return.
|
||||
:returns: A PaginatedResponse.
|
||||
"""
|
||||
...
|
||||
|
||||
|
@ -588,7 +594,7 @@ class Agents(Protocol):
|
|||
@webmethod(route="/openai/v1/responses", method="POST")
|
||||
async def create_openai_response(
|
||||
self,
|
||||
input: str | list[OpenAIResponseInputMessage],
|
||||
input: str | list[OpenAIResponseInput],
|
||||
model: str,
|
||||
previous_response_id: str | None = None,
|
||||
store: bool | None = True,
|
||||
|
@ -601,4 +607,6 @@ class Agents(Protocol):
|
|||
:param input: Input message(s) to create the response.
|
||||
:param model: The underlying LLM used for completions.
|
||||
:param previous_response_id: (Optional) if specified, the new response will be a continuation of the previous response. This can be used to easily fork-off new responses from existing responses.
|
||||
:returns: An OpenAIResponseObject.
|
||||
"""
|
||||
...
|
||||
|
|
|
@ -4,7 +4,7 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Annotated, Literal
|
||||
from typing import Annotated, Any, Literal
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
@ -17,6 +17,28 @@ class OpenAIResponseError(BaseModel):
|
|||
message: str
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class OpenAIResponseInputMessageContentText(BaseModel):
|
||||
text: str
|
||||
type: Literal["input_text"] = "input_text"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class OpenAIResponseInputMessageContentImage(BaseModel):
|
||||
detail: Literal["low"] | Literal["high"] | Literal["auto"] = "auto"
|
||||
type: Literal["input_image"] = "input_image"
|
||||
# TODO: handle file_id
|
||||
image_url: str | None = None
|
||||
|
||||
|
||||
# TODO: handle file content types
|
||||
OpenAIResponseInputMessageContent = Annotated[
|
||||
OpenAIResponseInputMessageContentText | OpenAIResponseInputMessageContentImage,
|
||||
Field(discriminator="type"),
|
||||
]
|
||||
register_schema(OpenAIResponseInputMessageContent, name="OpenAIResponseInputMessageContent")
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class OpenAIResponseOutputMessageContentOutputText(BaseModel):
|
||||
text: str
|
||||
|
@ -31,13 +53,22 @@ register_schema(OpenAIResponseOutputMessageContent, name="OpenAIResponseOutputMe
|
|||
|
||||
|
||||
@json_schema_type
|
||||
class OpenAIResponseOutputMessage(BaseModel):
|
||||
id: str
|
||||
content: list[OpenAIResponseOutputMessageContent]
|
||||
role: Literal["assistant"] = "assistant"
|
||||
status: str
|
||||
class OpenAIResponseMessage(BaseModel):
|
||||
"""
|
||||
Corresponds to the various Message types in the Responses API.
|
||||
They are all under one type because the Responses API gives them all
|
||||
the same "type" value, and there is no way to tell them apart in certain
|
||||
scenarios.
|
||||
"""
|
||||
|
||||
content: str | list[OpenAIResponseInputMessageContent] | list[OpenAIResponseOutputMessageContent]
|
||||
role: Literal["system"] | Literal["developer"] | Literal["user"] | Literal["assistant"]
|
||||
type: Literal["message"] = "message"
|
||||
|
||||
# The fields below are not used in all scenarios, but are required in others.
|
||||
id: str | None = None
|
||||
status: str | None = None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class OpenAIResponseOutputMessageWebSearchToolCall(BaseModel):
|
||||
|
@ -46,8 +77,18 @@ class OpenAIResponseOutputMessageWebSearchToolCall(BaseModel):
|
|||
type: Literal["web_search_call"] = "web_search_call"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class OpenAIResponseOutputMessageFunctionToolCall(BaseModel):
|
||||
arguments: str
|
||||
call_id: str
|
||||
name: str
|
||||
type: Literal["function_call"] = "function_call"
|
||||
id: str
|
||||
status: str
|
||||
|
||||
|
||||
OpenAIResponseOutput = Annotated[
|
||||
OpenAIResponseOutputMessage | OpenAIResponseOutputMessageWebSearchToolCall,
|
||||
OpenAIResponseMessage | OpenAIResponseOutputMessageWebSearchToolCall | OpenAIResponseOutputMessageFunctionToolCall,
|
||||
Field(discriminator="type"),
|
||||
]
|
||||
register_schema(OpenAIResponseOutput, name="OpenAIResponseOutput")
|
||||
|
@ -90,32 +131,29 @@ register_schema(OpenAIResponseObjectStream, name="OpenAIResponseObjectStream")
|
|||
|
||||
|
||||
@json_schema_type
|
||||
class OpenAIResponseInputMessageContentText(BaseModel):
|
||||
text: str
|
||||
type: Literal["input_text"] = "input_text"
|
||||
class OpenAIResponseInputFunctionToolCallOutput(BaseModel):
|
||||
"""
|
||||
This represents the output of a function call that gets passed back to the model.
|
||||
"""
|
||||
|
||||
call_id: str
|
||||
output: str
|
||||
type: Literal["function_call_output"] = "function_call_output"
|
||||
id: str | None = None
|
||||
status: str | None = None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class OpenAIResponseInputMessageContentImage(BaseModel):
|
||||
detail: Literal["low"] | Literal["high"] | Literal["auto"] = "auto"
|
||||
type: Literal["input_image"] = "input_image"
|
||||
# TODO: handle file_id
|
||||
image_url: str | None = None
|
||||
|
||||
|
||||
# TODO: handle file content types
|
||||
OpenAIResponseInputMessageContent = Annotated[
|
||||
OpenAIResponseInputMessageContentText | OpenAIResponseInputMessageContentImage,
|
||||
Field(discriminator="type"),
|
||||
OpenAIResponseInput = Annotated[
|
||||
# Responses API allows output messages to be passed in as input
|
||||
OpenAIResponseOutputMessageWebSearchToolCall
|
||||
| OpenAIResponseOutputMessageFunctionToolCall
|
||||
| OpenAIResponseInputFunctionToolCallOutput
|
||||
|
|
||||
# Fallback to the generic message type as a last resort
|
||||
OpenAIResponseMessage,
|
||||
Field(union_mode="left_to_right"),
|
||||
]
|
||||
register_schema(OpenAIResponseInputMessageContent, name="OpenAIResponseInputMessageContent")
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class OpenAIResponseInputMessage(BaseModel):
|
||||
content: str | list[OpenAIResponseInputMessageContent]
|
||||
role: Literal["system"] | Literal["developer"] | Literal["user"] | Literal["assistant"]
|
||||
type: Literal["message"] | None = "message"
|
||||
register_schema(OpenAIResponseInput, name="OpenAIResponseInput")
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
@ -126,8 +164,35 @@ class OpenAIResponseInputToolWebSearch(BaseModel):
|
|||
# TODO: add user_location
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class OpenAIResponseInputToolFunction(BaseModel):
|
||||
type: Literal["function"] = "function"
|
||||
name: str
|
||||
description: str | None = None
|
||||
parameters: dict[str, Any] | None
|
||||
strict: bool | None = None
|
||||
|
||||
|
||||
class FileSearchRankingOptions(BaseModel):
|
||||
ranker: str | None = None
|
||||
score_threshold: float | None = Field(default=0.0, ge=0.0, le=1.0)
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class OpenAIResponseInputToolFileSearch(BaseModel):
|
||||
type: Literal["file_search"] = "file_search"
|
||||
vector_store_id: list[str]
|
||||
ranking_options: FileSearchRankingOptions | None = None
|
||||
# TODO: add filters
|
||||
|
||||
|
||||
OpenAIResponseInputTool = Annotated[
|
||||
OpenAIResponseInputToolWebSearch,
|
||||
OpenAIResponseInputToolWebSearch | OpenAIResponseInputToolFileSearch | OpenAIResponseInputToolFunction,
|
||||
Field(discriminator="type"),
|
||||
]
|
||||
register_schema(OpenAIResponseInputTool, name="OpenAIResponseInputTool")
|
||||
|
||||
|
||||
class OpenAIResponseInputItemList(BaseModel):
|
||||
data: list[OpenAIResponseInput]
|
||||
object: Literal["list"] = "list"
|
||||
|
|
|
@ -38,7 +38,17 @@ class BatchInference(Protocol):
|
|||
sampling_params: SamplingParams | None = None,
|
||||
response_format: ResponseFormat | None = None,
|
||||
logprobs: LogProbConfig | None = None,
|
||||
) -> Job: ...
|
||||
) -> Job:
|
||||
"""Generate completions for a batch of content.
|
||||
|
||||
:param model: The model to use for the completion.
|
||||
:param content_batch: The content to complete.
|
||||
:param sampling_params: The sampling parameters to use for the completion.
|
||||
:param response_format: The response format to use for the completion.
|
||||
:param logprobs: The logprobs to use for the completion.
|
||||
:returns: A job for the completion.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/batch-inference/chat-completion", method="POST")
|
||||
async def chat_completion(
|
||||
|
@ -52,4 +62,17 @@ class BatchInference(Protocol):
|
|||
tool_prompt_format: ToolPromptFormat | None = None,
|
||||
response_format: ResponseFormat | None = None,
|
||||
logprobs: LogProbConfig | None = None,
|
||||
) -> Job: ...
|
||||
) -> Job:
|
||||
"""Generate chat completions for a batch of messages.
|
||||
|
||||
:param model: The model to use for the chat completion.
|
||||
:param messages_batch: The messages to complete.
|
||||
:param sampling_params: The sampling parameters to use for the completion.
|
||||
:param tools: The tools to use for the chat completion.
|
||||
:param tool_choice: The tool choice to use for the chat completion.
|
||||
:param tool_prompt_format: The tool prompt format to use for the chat completion.
|
||||
:param response_format: The response format to use for the chat completion.
|
||||
:param logprobs: The logprobs to use for the chat completion.
|
||||
:returns: A job for the chat completion.
|
||||
"""
|
||||
...
|
||||
|
|
|
@ -22,14 +22,14 @@ class CommonBenchmarkFields(BaseModel):
|
|||
|
||||
@json_schema_type
|
||||
class Benchmark(CommonBenchmarkFields, Resource):
|
||||
type: Literal[ResourceType.benchmark.value] = ResourceType.benchmark.value
|
||||
type: Literal[ResourceType.benchmark] = ResourceType.benchmark
|
||||
|
||||
@property
|
||||
def benchmark_id(self) -> str:
|
||||
return self.identifier
|
||||
|
||||
@property
|
||||
def provider_benchmark_id(self) -> str:
|
||||
def provider_benchmark_id(self) -> str | None:
|
||||
return self.provider_resource_id
|
||||
|
||||
|
||||
|
@ -46,13 +46,24 @@ class ListBenchmarksResponse(BaseModel):
|
|||
@runtime_checkable
|
||||
class Benchmarks(Protocol):
|
||||
@webmethod(route="/eval/benchmarks", method="GET")
|
||||
async def list_benchmarks(self) -> ListBenchmarksResponse: ...
|
||||
async def list_benchmarks(self) -> ListBenchmarksResponse:
|
||||
"""List all benchmarks.
|
||||
|
||||
:returns: A ListBenchmarksResponse.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/eval/benchmarks/{benchmark_id}", method="GET")
|
||||
async def get_benchmark(
|
||||
self,
|
||||
benchmark_id: str,
|
||||
) -> Benchmark: ...
|
||||
) -> Benchmark:
|
||||
"""Get a benchmark by its ID.
|
||||
|
||||
:param benchmark_id: The ID of the benchmark to get.
|
||||
:returns: A Benchmark.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/eval/benchmarks", method="POST")
|
||||
async def register_benchmark(
|
||||
|
@ -63,4 +74,14 @@ class Benchmarks(Protocol):
|
|||
provider_benchmark_id: str | None = None,
|
||||
provider_id: str | None = None,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
) -> None: ...
|
||||
) -> None:
|
||||
"""Register a benchmark.
|
||||
|
||||
:param benchmark_id: The ID of the benchmark to register.
|
||||
:param dataset_id: The ID of the dataset to use for the benchmark.
|
||||
:param scoring_functions: The scoring functions to use for the benchmark.
|
||||
:param provider_benchmark_id: The ID of the provider benchmark to use for the benchmark.
|
||||
:param provider_id: The ID of the provider to use for the benchmark.
|
||||
:param metadata: The metadata to use for the benchmark.
|
||||
"""
|
||||
...
|
||||
|
|
|
@ -28,7 +28,7 @@ class _URLOrData(BaseModel):
|
|||
|
||||
url: URL | None = None
|
||||
# data is a base64 encoded string, hint with contentEncoding=base64
|
||||
data: str | None = Field(contentEncoding="base64", default=None)
|
||||
data: str | None = Field(default=None, json_schema_extra={"contentEncoding": "base64"})
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
|
|
|
@ -34,14 +34,21 @@ class DatasetIO(Protocol):
|
|||
- limit: Number of items to return. If None or -1, returns all items.
|
||||
|
||||
The response includes:
|
||||
- data: List of items for the current page
|
||||
- has_more: Whether there are more items available after this set
|
||||
- data: List of items for the current page.
|
||||
- has_more: Whether there are more items available after this set.
|
||||
|
||||
:param dataset_id: The ID of the dataset to get the rows from.
|
||||
:param start_index: Index into dataset for the first row to get. Get all rows if None.
|
||||
:param limit: The number of rows to get.
|
||||
:returns: A PaginatedResponse.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/datasetio/append-rows/{dataset_id:path}", method="POST")
|
||||
async def append_rows(self, dataset_id: str, rows: list[dict[str, Any]]) -> None: ...
|
||||
async def append_rows(self, dataset_id: str, rows: list[dict[str, Any]]) -> None:
|
||||
"""Append rows to a dataset.
|
||||
|
||||
:param dataset_id: The ID of the dataset to append the rows to.
|
||||
:param rows: The rows to append to the dataset.
|
||||
"""
|
||||
...
|
||||
|
|
|
@ -106,14 +106,14 @@ class CommonDatasetFields(BaseModel):
|
|||
|
||||
@json_schema_type
|
||||
class Dataset(CommonDatasetFields, Resource):
|
||||
type: Literal[ResourceType.dataset.value] = ResourceType.dataset.value
|
||||
type: Literal[ResourceType.dataset] = ResourceType.dataset
|
||||
|
||||
@property
|
||||
def dataset_id(self) -> str:
|
||||
return self.identifier
|
||||
|
||||
@property
|
||||
def provider_dataset_id(self) -> str:
|
||||
def provider_dataset_id(self) -> str | None:
|
||||
return self.provider_resource_id
|
||||
|
||||
|
||||
|
@ -137,7 +137,8 @@ class Datasets(Protocol):
|
|||
"""
|
||||
Register a new dataset.
|
||||
|
||||
:param purpose: The purpose of the dataset. One of
|
||||
:param purpose: The purpose of the dataset.
|
||||
One of:
|
||||
- "post-training/messages": The dataset contains a messages column with list of messages for post-training.
|
||||
{
|
||||
"messages": [
|
||||
|
@ -188,8 +189,9 @@ class Datasets(Protocol):
|
|||
]
|
||||
}
|
||||
:param metadata: The metadata for the dataset.
|
||||
- E.g. {"description": "My dataset"}
|
||||
- E.g. {"description": "My dataset"}.
|
||||
:param dataset_id: The ID of the dataset. If not provided, an ID will be generated.
|
||||
:returns: A Dataset.
|
||||
"""
|
||||
...
|
||||
|
||||
|
@ -197,13 +199,29 @@ class Datasets(Protocol):
|
|||
async def get_dataset(
|
||||
self,
|
||||
dataset_id: str,
|
||||
) -> Dataset: ...
|
||||
) -> Dataset:
|
||||
"""Get a dataset by its ID.
|
||||
|
||||
:param dataset_id: The ID of the dataset to get.
|
||||
:returns: A Dataset.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/datasets", method="GET")
|
||||
async def list_datasets(self) -> ListDatasetsResponse: ...
|
||||
async def list_datasets(self) -> ListDatasetsResponse:
|
||||
"""List all datasets.
|
||||
|
||||
:returns: A ListDatasetsResponse.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/datasets/{dataset_id:path}", method="DELETE")
|
||||
async def unregister_dataset(
|
||||
self,
|
||||
dataset_id: str,
|
||||
) -> None: ...
|
||||
) -> None:
|
||||
"""Unregister a dataset by its ID.
|
||||
|
||||
:param dataset_id: The ID of the dataset to unregister.
|
||||
"""
|
||||
...
|
||||
|
|
|
@ -93,8 +93,9 @@ class Eval(Protocol):
|
|||
|
||||
:param benchmark_id: The ID of the benchmark to run the evaluation on.
|
||||
:param benchmark_config: The configuration for the benchmark.
|
||||
:return: The job that was created to run the evaluation.
|
||||
:returns: The job that was created to run the evaluation.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/eval/benchmarks/{benchmark_id}/evaluations", method="POST")
|
||||
async def evaluate_rows(
|
||||
|
@ -110,8 +111,9 @@ class Eval(Protocol):
|
|||
:param input_rows: The rows to evaluate.
|
||||
:param scoring_functions: The scoring functions to use for the evaluation.
|
||||
:param benchmark_config: The configuration for the benchmark.
|
||||
:return: EvaluateResponse object containing generations and scores
|
||||
:returns: EvaluateResponse object containing generations and scores.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/eval/benchmarks/{benchmark_id}/jobs/{job_id}", method="GET")
|
||||
async def job_status(self, benchmark_id: str, job_id: str) -> Job:
|
||||
|
@ -119,7 +121,7 @@ class Eval(Protocol):
|
|||
|
||||
:param benchmark_id: The ID of the benchmark to run the evaluation on.
|
||||
:param job_id: The ID of the job to get the status of.
|
||||
:return: The status of the evaluationjob.
|
||||
:returns: The status of the evaluation job.
|
||||
"""
|
||||
...
|
||||
|
||||
|
@ -138,5 +140,6 @@ class Eval(Protocol):
|
|||
|
||||
:param benchmark_id: The ID of the benchmark to run the evaluation on.
|
||||
:param job_id: The ID of the job to get the result of.
|
||||
:return: The result of the job.
|
||||
:returns: The result of the job.
|
||||
"""
|
||||
...
|
||||
|
|
|
@ -91,10 +91,11 @@ class Files(Protocol):
|
|||
"""
|
||||
Create a new upload session for a file identified by a bucket and key.
|
||||
|
||||
:param bucket: Bucket under which the file is stored (valid chars: a-zA-Z0-9_-)
|
||||
:param key: Key under which the file is stored (valid chars: a-zA-Z0-9_-/.)
|
||||
:param mime_type: MIME type of the file
|
||||
:param size: File size in bytes
|
||||
:param bucket: Bucket under which the file is stored (valid chars: a-zA-Z0-9_-).
|
||||
:param key: Key under which the file is stored (valid chars: a-zA-Z0-9_-/.).
|
||||
:param mime_type: MIME type of the file.
|
||||
:param size: File size in bytes.
|
||||
:returns: A FileUploadResponse.
|
||||
"""
|
||||
...
|
||||
|
||||
|
@ -107,7 +108,8 @@ class Files(Protocol):
|
|||
Upload file content to an existing upload session.
|
||||
On the server, request body will have the raw bytes that are uploaded.
|
||||
|
||||
:param upload_id: ID of the upload session
|
||||
:param upload_id: ID of the upload session.
|
||||
:returns: A FileResponse or None if the upload is not complete.
|
||||
"""
|
||||
...
|
||||
|
||||
|
@ -117,9 +119,10 @@ class Files(Protocol):
|
|||
upload_id: str,
|
||||
) -> FileUploadResponse:
|
||||
"""
|
||||
Returns information about an existsing upload session
|
||||
Returns information about an existsing upload session.
|
||||
|
||||
:param upload_id: ID of the upload session
|
||||
:param upload_id: ID of the upload session.
|
||||
:returns: A FileUploadResponse.
|
||||
"""
|
||||
...
|
||||
|
||||
|
@ -130,6 +133,9 @@ class Files(Protocol):
|
|||
) -> ListBucketResponse:
|
||||
"""
|
||||
List all buckets.
|
||||
|
||||
:param bucket: Bucket name (valid chars: a-zA-Z0-9_-).
|
||||
:returns: A ListBucketResponse.
|
||||
"""
|
||||
...
|
||||
|
||||
|
@ -141,7 +147,8 @@ class Files(Protocol):
|
|||
"""
|
||||
List all files in a bucket.
|
||||
|
||||
:param bucket: Bucket name (valid chars: a-zA-Z0-9_-)
|
||||
:param bucket: Bucket name (valid chars: a-zA-Z0-9_-).
|
||||
:returns: A ListFileResponse.
|
||||
"""
|
||||
...
|
||||
|
||||
|
@ -154,8 +161,9 @@ class Files(Protocol):
|
|||
"""
|
||||
Get a file info identified by a bucket and key.
|
||||
|
||||
:param bucket: Bucket name (valid chars: a-zA-Z0-9_-)
|
||||
:param key: Key under which the file is stored (valid chars: a-zA-Z0-9_-/.)
|
||||
:param bucket: Bucket name (valid chars: a-zA-Z0-9_-).
|
||||
:param key: Key under which the file is stored (valid chars: a-zA-Z0-9_-/.).
|
||||
:returns: A FileResponse.
|
||||
"""
|
||||
...
|
||||
|
||||
|
@ -168,7 +176,7 @@ class Files(Protocol):
|
|||
"""
|
||||
Delete a file identified by a bucket and key.
|
||||
|
||||
:param bucket: Bucket name (valid chars: a-zA-Z0-9_-)
|
||||
:param key: Key under which the file is stored (valid chars: a-zA-Z0-9_-/.)
|
||||
:param bucket: Bucket name (valid chars: a-zA-Z0-9_-).
|
||||
:param key: Key under which the file is stored (valid chars: a-zA-Z0-9_-/.).
|
||||
"""
|
||||
...
|
||||
|
|
|
@ -4,6 +4,7 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import sys
|
||||
from collections.abc import AsyncIterator
|
||||
from enum import Enum
|
||||
from typing import (
|
||||
|
@ -35,6 +36,16 @@ register_schema(ToolCall)
|
|||
register_schema(ToolParamDefinition)
|
||||
register_schema(ToolDefinition)
|
||||
|
||||
# TODO: use enum.StrEnum when we drop support for python 3.10
|
||||
if sys.version_info >= (3, 11):
|
||||
from enum import StrEnum
|
||||
else:
|
||||
|
||||
class StrEnum(str, Enum):
|
||||
"""Backport of StrEnum for Python 3.10 and below."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class GreedySamplingStrategy(BaseModel):
|
||||
|
@ -187,7 +198,7 @@ class CompletionMessage(BaseModel):
|
|||
role: Literal["assistant"] = "assistant"
|
||||
content: InterleavedContent
|
||||
stop_reason: StopReason
|
||||
tool_calls: list[ToolCall] | None = Field(default_factory=list)
|
||||
tool_calls: list[ToolCall] | None = Field(default_factory=lambda: [])
|
||||
|
||||
|
||||
Message = Annotated[
|
||||
|
@ -267,7 +278,7 @@ class ChatCompletionResponseEvent(BaseModel):
|
|||
stop_reason: StopReason | None = None
|
||||
|
||||
|
||||
class ResponseFormatType(Enum):
|
||||
class ResponseFormatType(StrEnum):
|
||||
"""Types of formats for structured (guided) decoding.
|
||||
|
||||
:cvar json_schema: Response should conform to a JSON schema. In a Python SDK, this is often a `pydantic` model.
|
||||
|
@ -286,7 +297,7 @@ class JsonSchemaResponseFormat(BaseModel):
|
|||
:param json_schema: The JSON schema the response should conform to. In a Python SDK, this is often a `pydantic` model.
|
||||
"""
|
||||
|
||||
type: Literal[ResponseFormatType.json_schema.value] = ResponseFormatType.json_schema.value
|
||||
type: Literal[ResponseFormatType.json_schema] = ResponseFormatType.json_schema
|
||||
json_schema: dict[str, Any]
|
||||
|
||||
|
||||
|
@ -298,7 +309,7 @@ class GrammarResponseFormat(BaseModel):
|
|||
:param bnf: The BNF grammar specification the response should conform to
|
||||
"""
|
||||
|
||||
type: Literal[ResponseFormatType.grammar.value] = ResponseFormatType.grammar.value
|
||||
type: Literal[ResponseFormatType.grammar] = ResponseFormatType.grammar
|
||||
bnf: dict[str, Any]
|
||||
|
||||
|
||||
|
@ -394,7 +405,7 @@ class ChatCompletionRequest(BaseModel):
|
|||
messages: list[Message]
|
||||
sampling_params: SamplingParams | None = Field(default_factory=SamplingParams)
|
||||
|
||||
tools: list[ToolDefinition] | None = Field(default_factory=list)
|
||||
tools: list[ToolDefinition] | None = Field(default_factory=lambda: [])
|
||||
tool_config: ToolConfig | None = Field(default_factory=ToolConfig)
|
||||
|
||||
response_format: ResponseFormat | None = None
|
||||
|
@ -567,14 +578,14 @@ class OpenAIResponseFormatText(BaseModel):
|
|||
@json_schema_type
|
||||
class OpenAIJSONSchema(TypedDict, total=False):
|
||||
name: str
|
||||
description: str | None = None
|
||||
strict: bool | None = None
|
||||
description: str | None
|
||||
strict: bool | None
|
||||
|
||||
# Pydantic BaseModel cannot be used with a schema param, since it already
|
||||
# has one. And, we don't want to alias here because then have to handle
|
||||
# that alias when converting to OpenAI params. So, to support schema,
|
||||
# we use a TypedDict.
|
||||
schema: dict[str, Any] | None = None
|
||||
schema: dict[str, Any] | None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
@ -809,15 +820,32 @@ class BatchChatCompletionResponse(BaseModel):
|
|||
batch: list[ChatCompletionResponse]
|
||||
|
||||
|
||||
class OpenAICompletionWithInputMessages(OpenAIChatCompletion):
|
||||
input_messages: list[OpenAIMessageParam]
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ListOpenAIChatCompletionResponse(BaseModel):
|
||||
data: list[OpenAICompletionWithInputMessages]
|
||||
has_more: bool
|
||||
first_id: str
|
||||
last_id: str
|
||||
object: Literal["list"] = "list"
|
||||
|
||||
|
||||
class Order(Enum):
|
||||
asc = "asc"
|
||||
desc = "desc"
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
@trace_protocol
|
||||
class Inference(Protocol):
|
||||
"""Llama Stack Inference API for generating completions, chat completions, and embeddings.
|
||||
|
||||
This API provides the raw interface to the underlying models. Two kinds of models are supported:
|
||||
- LLM models: these models generate "raw" and "chat" (conversational) completions.
|
||||
- Embedding models: these models generate embeddings to be used for semantic search.
|
||||
class InferenceProvider(Protocol):
|
||||
"""
|
||||
This protocol defines the interface that should be implemented by all inference providers.
|
||||
"""
|
||||
|
||||
API_NAMESPACE: str = "Inference"
|
||||
|
||||
model_store: ModelStore | None = None
|
||||
|
||||
|
@ -834,13 +862,13 @@ class Inference(Protocol):
|
|||
"""Generate a completion for the given content using the specified model.
|
||||
|
||||
:param model_id: The identifier of the model to use. The model must be registered with Llama Stack and available via the /models endpoint.
|
||||
:param content: The content to generate a completion for
|
||||
:param sampling_params: (Optional) Parameters to control the sampling strategy
|
||||
:param response_format: (Optional) Grammar specification for guided (structured) decoding
|
||||
:param content: The content to generate a completion for.
|
||||
:param sampling_params: (Optional) Parameters to control the sampling strategy.
|
||||
:param response_format: (Optional) Grammar specification for guided (structured) decoding.
|
||||
:param stream: (Optional) If True, generate an SSE event stream of the response. Defaults to False.
|
||||
:param logprobs: (Optional) If specified, log probabilities for each token position will be returned.
|
||||
:returns: If stream=False, returns a CompletionResponse with the full completion.
|
||||
If stream=True, returns an SSE event stream of CompletionResponseStreamChunk
|
||||
If stream=True, returns an SSE event stream of CompletionResponseStreamChunk.
|
||||
"""
|
||||
...
|
||||
|
||||
|
@ -853,6 +881,15 @@ class Inference(Protocol):
|
|||
response_format: ResponseFormat | None = None,
|
||||
logprobs: LogProbConfig | None = None,
|
||||
) -> BatchCompletionResponse:
|
||||
"""Generate completions for a batch of content using the specified model.
|
||||
|
||||
:param model_id: The identifier of the model to use. The model must be registered with Llama Stack and available via the /models endpoint.
|
||||
:param content_batch: The content to generate completions for.
|
||||
:param sampling_params: (Optional) Parameters to control the sampling strategy.
|
||||
:param response_format: (Optional) Grammar specification for guided (structured) decoding.
|
||||
:param logprobs: (Optional) If specified, log probabilities for each token position will be returned.
|
||||
:returns: A BatchCompletionResponse with the full completions.
|
||||
"""
|
||||
raise NotImplementedError("Batch completion is not implemented")
|
||||
|
||||
@webmethod(route="/inference/chat-completion", method="POST")
|
||||
|
@ -872,9 +909,9 @@ class Inference(Protocol):
|
|||
"""Generate a chat completion for the given messages using the specified model.
|
||||
|
||||
:param model_id: The identifier of the model to use. The model must be registered with Llama Stack and available via the /models endpoint.
|
||||
:param messages: List of messages in the conversation
|
||||
:param sampling_params: Parameters to control the sampling strategy
|
||||
:param tools: (Optional) List of tool definitions available to the model
|
||||
:param messages: List of messages in the conversation.
|
||||
:param sampling_params: Parameters to control the sampling strategy.
|
||||
:param tools: (Optional) List of tool definitions available to the model.
|
||||
:param tool_choice: (Optional) Whether tool use is required or automatic. Defaults to ToolChoice.auto.
|
||||
.. deprecated::
|
||||
Use tool_config instead.
|
||||
|
@ -891,7 +928,7 @@ class Inference(Protocol):
|
|||
:param logprobs: (Optional) If specified, log probabilities for each token position will be returned.
|
||||
:param tool_config: (Optional) Configuration for tool use.
|
||||
:returns: If stream=False, returns a ChatCompletionResponse with the full completion.
|
||||
If stream=True, returns an SSE event stream of ChatCompletionResponseStreamChunk
|
||||
If stream=True, returns an SSE event stream of ChatCompletionResponseStreamChunk.
|
||||
"""
|
||||
...
|
||||
|
||||
|
@ -906,6 +943,17 @@ class Inference(Protocol):
|
|||
response_format: ResponseFormat | None = None,
|
||||
logprobs: LogProbConfig | None = None,
|
||||
) -> BatchChatCompletionResponse:
|
||||
"""Generate chat completions for a batch of messages using the specified model.
|
||||
|
||||
:param model_id: The identifier of the model to use. The model must be registered with Llama Stack and available via the /models endpoint.
|
||||
:param messages_batch: The messages to generate completions for.
|
||||
:param sampling_params: (Optional) Parameters to control the sampling strategy.
|
||||
:param tools: (Optional) List of tool definitions available to the model.
|
||||
:param tool_config: (Optional) Configuration for tool use.
|
||||
:param response_format: (Optional) Grammar specification for guided (structured) decoding.
|
||||
:param logprobs: (Optional) If specified, log probabilities for each token position will be returned.
|
||||
:returns: A BatchChatCompletionResponse with the full completions.
|
||||
"""
|
||||
raise NotImplementedError("Batch chat completion is not implemented")
|
||||
|
||||
@webmethod(route="/inference/embeddings", method="POST")
|
||||
|
@ -924,7 +972,7 @@ class Inference(Protocol):
|
|||
:param output_dimension: (Optional) Output dimensionality for the embeddings. Only supported by Matryoshka models.
|
||||
:param text_truncation: (Optional) Config for how to truncate text for embedding when text is longer than the model's max sequence length.
|
||||
:param task_type: (Optional) How is the embedding being used? This is only supported by asymmetric embedding models.
|
||||
:returns: An array of embeddings, one for each content. Each embedding is a list of floats. The dimensionality of the embedding is model-specific; you can check model metadata using /models/{model_id}
|
||||
:returns: An array of embeddings, one for each content. Each embedding is a list of floats. The dimensionality of the embedding is model-specific; you can check model metadata using /models/{model_id}.
|
||||
"""
|
||||
...
|
||||
|
||||
|
@ -956,22 +1004,23 @@ class Inference(Protocol):
|
|||
"""Generate an OpenAI-compatible completion for the given prompt using the specified model.
|
||||
|
||||
:param model: The identifier of the model to use. The model must be registered with Llama Stack and available via the /models endpoint.
|
||||
:param prompt: The prompt to generate a completion for
|
||||
:param best_of: (Optional) The number of completions to generate
|
||||
:param echo: (Optional) Whether to echo the prompt
|
||||
:param frequency_penalty: (Optional) The penalty for repeated tokens
|
||||
:param logit_bias: (Optional) The logit bias to use
|
||||
:param logprobs: (Optional) The log probabilities to use
|
||||
:param max_tokens: (Optional) The maximum number of tokens to generate
|
||||
:param n: (Optional) The number of completions to generate
|
||||
:param presence_penalty: (Optional) The penalty for repeated tokens
|
||||
:param seed: (Optional) The seed to use
|
||||
:param stop: (Optional) The stop tokens to use
|
||||
:param stream: (Optional) Whether to stream the response
|
||||
:param stream_options: (Optional) The stream options to use
|
||||
:param temperature: (Optional) The temperature to use
|
||||
:param top_p: (Optional) The top p to use
|
||||
:param user: (Optional) The user to use
|
||||
:param prompt: The prompt to generate a completion for.
|
||||
:param best_of: (Optional) The number of completions to generate.
|
||||
:param echo: (Optional) Whether to echo the prompt.
|
||||
:param frequency_penalty: (Optional) The penalty for repeated tokens.
|
||||
:param logit_bias: (Optional) The logit bias to use.
|
||||
:param logprobs: (Optional) The log probabilities to use.
|
||||
:param max_tokens: (Optional) The maximum number of tokens to generate.
|
||||
:param n: (Optional) The number of completions to generate.
|
||||
:param presence_penalty: (Optional) The penalty for repeated tokens.
|
||||
:param seed: (Optional) The seed to use.
|
||||
:param stop: (Optional) The stop tokens to use.
|
||||
:param stream: (Optional) Whether to stream the response.
|
||||
:param stream_options: (Optional) The stream options to use.
|
||||
:param temperature: (Optional) The temperature to use.
|
||||
:param top_p: (Optional) The top p to use.
|
||||
:param user: (Optional) The user to use.
|
||||
:returns: An OpenAICompletion.
|
||||
"""
|
||||
...
|
||||
|
||||
|
@ -1005,27 +1054,64 @@ class Inference(Protocol):
|
|||
"""Generate an OpenAI-compatible chat completion for the given messages using the specified model.
|
||||
|
||||
:param model: The identifier of the model to use. The model must be registered with Llama Stack and available via the /models endpoint.
|
||||
:param messages: List of messages in the conversation
|
||||
:param frequency_penalty: (Optional) The penalty for repeated tokens
|
||||
:param function_call: (Optional) The function call to use
|
||||
:param functions: (Optional) List of functions to use
|
||||
:param logit_bias: (Optional) The logit bias to use
|
||||
:param logprobs: (Optional) The log probabilities to use
|
||||
:param max_completion_tokens: (Optional) The maximum number of tokens to generate
|
||||
:param max_tokens: (Optional) The maximum number of tokens to generate
|
||||
:param n: (Optional) The number of completions to generate
|
||||
:param parallel_tool_calls: (Optional) Whether to parallelize tool calls
|
||||
:param presence_penalty: (Optional) The penalty for repeated tokens
|
||||
:param response_format: (Optional) The response format to use
|
||||
:param seed: (Optional) The seed to use
|
||||
:param stop: (Optional) The stop tokens to use
|
||||
:param stream: (Optional) Whether to stream the response
|
||||
:param stream_options: (Optional) The stream options to use
|
||||
:param temperature: (Optional) The temperature to use
|
||||
:param tool_choice: (Optional) The tool choice to use
|
||||
:param tools: (Optional) The tools to use
|
||||
:param top_logprobs: (Optional) The top log probabilities to use
|
||||
:param top_p: (Optional) The top p to use
|
||||
:param user: (Optional) The user to use
|
||||
:param messages: List of messages in the conversation.
|
||||
:param frequency_penalty: (Optional) The penalty for repeated tokens.
|
||||
:param function_call: (Optional) The function call to use.
|
||||
:param functions: (Optional) List of functions to use.
|
||||
:param logit_bias: (Optional) The logit bias to use.
|
||||
:param logprobs: (Optional) The log probabilities to use.
|
||||
:param max_completion_tokens: (Optional) The maximum number of tokens to generate.
|
||||
:param max_tokens: (Optional) The maximum number of tokens to generate.
|
||||
:param n: (Optional) The number of completions to generate.
|
||||
:param parallel_tool_calls: (Optional) Whether to parallelize tool calls.
|
||||
:param presence_penalty: (Optional) The penalty for repeated tokens.
|
||||
:param response_format: (Optional) The response format to use.
|
||||
:param seed: (Optional) The seed to use.
|
||||
:param stop: (Optional) The stop tokens to use.
|
||||
:param stream: (Optional) Whether to stream the response.
|
||||
:param stream_options: (Optional) The stream options to use.
|
||||
:param temperature: (Optional) The temperature to use.
|
||||
:param tool_choice: (Optional) The tool choice to use.
|
||||
:param tools: (Optional) The tools to use.
|
||||
:param top_logprobs: (Optional) The top log probabilities to use.
|
||||
:param top_p: (Optional) The top p to use.
|
||||
:param user: (Optional) The user to use.
|
||||
:returns: An OpenAIChatCompletion.
|
||||
"""
|
||||
...
|
||||
|
||||
|
||||
class Inference(InferenceProvider):
|
||||
"""Llama Stack Inference API for generating completions, chat completions, and embeddings.
|
||||
|
||||
This API provides the raw interface to the underlying models. Two kinds of models are supported:
|
||||
- LLM models: these models generate "raw" and "chat" (conversational) completions.
|
||||
- Embedding models: these models generate embeddings to be used for semantic search.
|
||||
"""
|
||||
|
||||
@webmethod(route="/openai/v1/chat/completions", method="GET")
|
||||
async def list_chat_completions(
|
||||
self,
|
||||
after: str | None = None,
|
||||
limit: int | None = 20,
|
||||
model: str | None = None,
|
||||
order: Order | None = Order.desc,
|
||||
) -> ListOpenAIChatCompletionResponse:
|
||||
"""List all chat completions.
|
||||
|
||||
:param after: The ID of the last chat completion to return.
|
||||
:param limit: The maximum number of chat completions to return.
|
||||
:param model: The model to filter by.
|
||||
:param order: The order to sort the chat completions by: "asc" or "desc". Defaults to "desc".
|
||||
:returns: A ListOpenAIChatCompletionResponse.
|
||||
"""
|
||||
raise NotImplementedError("List chat completions is not implemented")
|
||||
|
||||
@webmethod(route="/openai/v1/chat/completions/{completion_id}", method="GET")
|
||||
async def get_chat_completion(self, completion_id: str) -> OpenAICompletionWithInputMessages:
|
||||
"""Describe a chat completion by its ID.
|
||||
|
||||
:param completion_id: ID of the chat completion.
|
||||
:returns: A OpenAICompletionWithInputMessages.
|
||||
"""
|
||||
raise NotImplementedError("Get chat completion is not implemented")
|
||||
|
|
|
@ -36,10 +36,25 @@ class ListRoutesResponse(BaseModel):
|
|||
@runtime_checkable
|
||||
class Inspect(Protocol):
|
||||
@webmethod(route="/inspect/routes", method="GET")
|
||||
async def list_routes(self) -> ListRoutesResponse: ...
|
||||
async def list_routes(self) -> ListRoutesResponse:
|
||||
"""List all routes.
|
||||
|
||||
:returns: A ListRoutesResponse.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/health", method="GET")
|
||||
async def health(self) -> HealthInfo: ...
|
||||
async def health(self) -> HealthInfo:
|
||||
"""Get the health of the service.
|
||||
|
||||
:returns: A HealthInfo.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/version", method="GET")
|
||||
async def version(self) -> VersionInfo: ...
|
||||
async def version(self) -> VersionInfo:
|
||||
"""Get the version of the service.
|
||||
|
||||
:returns: A VersionInfo.
|
||||
"""
|
||||
...
|
||||
|
|
|
@ -29,14 +29,14 @@ class ModelType(str, Enum):
|
|||
|
||||
@json_schema_type
|
||||
class Model(CommonModelFields, Resource):
|
||||
type: Literal[ResourceType.model.value] = ResourceType.model.value
|
||||
type: Literal[ResourceType.model] = ResourceType.model
|
||||
|
||||
@property
|
||||
def model_id(self) -> str:
|
||||
return self.identifier
|
||||
|
||||
@property
|
||||
def provider_model_id(self) -> str:
|
||||
def provider_model_id(self) -> str | None:
|
||||
return self.provider_resource_id
|
||||
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
|
@ -80,16 +80,32 @@ class OpenAIListModelsResponse(BaseModel):
|
|||
@trace_protocol
|
||||
class Models(Protocol):
|
||||
@webmethod(route="/models", method="GET")
|
||||
async def list_models(self) -> ListModelsResponse: ...
|
||||
async def list_models(self) -> ListModelsResponse:
|
||||
"""List all models.
|
||||
|
||||
:returns: A ListModelsResponse.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/openai/v1/models", method="GET")
|
||||
async def openai_list_models(self) -> OpenAIListModelsResponse: ...
|
||||
async def openai_list_models(self) -> OpenAIListModelsResponse:
|
||||
"""List models using the OpenAI API.
|
||||
|
||||
:returns: A OpenAIListModelsResponse.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/models/{model_id:path}", method="GET")
|
||||
async def get_model(
|
||||
self,
|
||||
model_id: str,
|
||||
) -> Model: ...
|
||||
) -> Model:
|
||||
"""Get a model by its identifier.
|
||||
|
||||
:param model_id: The identifier of the model to get.
|
||||
:returns: A Model.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/models", method="POST")
|
||||
async def register_model(
|
||||
|
@ -99,10 +115,25 @@ class Models(Protocol):
|
|||
provider_id: str | None = None,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
model_type: ModelType | None = None,
|
||||
) -> Model: ...
|
||||
) -> Model:
|
||||
"""Register a model.
|
||||
|
||||
:param model_id: The identifier of the model to register.
|
||||
:param provider_model_id: The identifier of the model in the provider.
|
||||
:param provider_id: The identifier of the provider.
|
||||
:param metadata: Any additional metadata for this model.
|
||||
:param model_type: The type of model to register.
|
||||
:returns: A Model.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/models/{model_id:path}", method="DELETE")
|
||||
async def unregister_model(
|
||||
self,
|
||||
model_id: str,
|
||||
) -> None: ...
|
||||
) -> None:
|
||||
"""Unregister a model.
|
||||
|
||||
:param model_id: The identifier of the model to unregister.
|
||||
"""
|
||||
...
|
||||
|
|
|
@ -182,7 +182,19 @@ class PostTraining(Protocol):
|
|||
),
|
||||
checkpoint_dir: str | None = None,
|
||||
algorithm_config: AlgorithmConfig | None = None,
|
||||
) -> PostTrainingJob: ...
|
||||
) -> PostTrainingJob:
|
||||
"""Run supervised fine-tuning of a model.
|
||||
|
||||
:param job_uuid: The UUID of the job to create.
|
||||
:param training_config: The training configuration.
|
||||
:param hyperparam_search_config: The hyperparam search configuration.
|
||||
:param logger_config: The logger configuration.
|
||||
:param model: The model to fine-tune.
|
||||
:param checkpoint_dir: The directory to save checkpoint(s) to.
|
||||
:param algorithm_config: The algorithm configuration.
|
||||
:returns: A PostTrainingJob.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/post-training/preference-optimize", method="POST")
|
||||
async def preference_optimize(
|
||||
|
@ -193,16 +205,49 @@ class PostTraining(Protocol):
|
|||
training_config: TrainingConfig,
|
||||
hyperparam_search_config: dict[str, Any],
|
||||
logger_config: dict[str, Any],
|
||||
) -> PostTrainingJob: ...
|
||||
) -> PostTrainingJob:
|
||||
"""Run preference optimization of a model.
|
||||
|
||||
:param job_uuid: The UUID of the job to create.
|
||||
:param finetuned_model: The model to fine-tune.
|
||||
:param algorithm_config: The algorithm configuration.
|
||||
:param training_config: The training configuration.
|
||||
:param hyperparam_search_config: The hyperparam search configuration.
|
||||
:param logger_config: The logger configuration.
|
||||
:returns: A PostTrainingJob.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/post-training/jobs", method="GET")
|
||||
async def get_training_jobs(self) -> ListPostTrainingJobsResponse: ...
|
||||
async def get_training_jobs(self) -> ListPostTrainingJobsResponse:
|
||||
"""Get all training jobs.
|
||||
|
||||
:returns: A ListPostTrainingJobsResponse.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/post-training/job/status", method="GET")
|
||||
async def get_training_job_status(self, job_uuid: str) -> PostTrainingJobStatusResponse: ...
|
||||
async def get_training_job_status(self, job_uuid: str) -> PostTrainingJobStatusResponse:
|
||||
"""Get the status of a training job.
|
||||
|
||||
:param job_uuid: The UUID of the job to get the status of.
|
||||
:returns: A PostTrainingJobStatusResponse.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/post-training/job/cancel", method="POST")
|
||||
async def cancel_training_job(self, job_uuid: str) -> None: ...
|
||||
async def cancel_training_job(self, job_uuid: str) -> None:
|
||||
"""Cancel a training job.
|
||||
|
||||
:param job_uuid: The UUID of the job to cancel.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/post-training/job/artifacts", method="GET")
|
||||
async def get_training_job_artifacts(self, job_uuid: str) -> PostTrainingJobArtifactsResponse: ...
|
||||
async def get_training_job_artifacts(self, job_uuid: str) -> PostTrainingJobArtifactsResponse:
|
||||
"""Get the artifacts of a training job.
|
||||
|
||||
:param job_uuid: The UUID of the job to get the artifacts of.
|
||||
:returns: A PostTrainingJobArtifactsResponse.
|
||||
"""
|
||||
...
|
||||
|
|
|
@ -32,7 +32,18 @@ class Providers(Protocol):
|
|||
"""
|
||||
|
||||
@webmethod(route="/providers", method="GET")
|
||||
async def list_providers(self) -> ListProvidersResponse: ...
|
||||
async def list_providers(self) -> ListProvidersResponse:
|
||||
"""List all available providers.
|
||||
|
||||
:returns: A ListProvidersResponse containing information about all providers.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/providers/{provider_id}", method="GET")
|
||||
async def inspect_provider(self, provider_id: str) -> ProviderInfo: ...
|
||||
async def inspect_provider(self, provider_id: str) -> ProviderInfo:
|
||||
"""Get detailed information about a specific provider.
|
||||
|
||||
:param provider_id: The ID of the provider to inspect.
|
||||
:returns: A ProviderInfo object containing the provider's details.
|
||||
"""
|
||||
...
|
||||
|
|
|
@ -4,12 +4,23 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import sys
|
||||
from enum import Enum
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
# TODO: use enum.StrEnum when we drop support for python 3.10
|
||||
if sys.version_info >= (3, 11):
|
||||
from enum import StrEnum
|
||||
else:
|
||||
|
||||
class ResourceType(Enum):
|
||||
class StrEnum(str, Enum):
|
||||
"""Backport of StrEnum for Python 3.10 and below."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class ResourceType(StrEnum):
|
||||
model = "model"
|
||||
shield = "shield"
|
||||
vector_db = "vector_db"
|
||||
|
@ -25,9 +36,9 @@ class Resource(BaseModel):
|
|||
|
||||
identifier: str = Field(description="Unique identifier for this resource in llama stack")
|
||||
|
||||
provider_resource_id: str = Field(
|
||||
description="Unique identifier for this resource in the provider",
|
||||
provider_resource_id: str | None = Field(
|
||||
default=None,
|
||||
description="Unique identifier for this resource in the provider",
|
||||
)
|
||||
|
||||
provider_id: str = Field(description="ID of the provider that owns this resource")
|
||||
|
|
|
@ -53,5 +53,13 @@ class Safety(Protocol):
|
|||
self,
|
||||
shield_id: str,
|
||||
messages: list[Message],
|
||||
params: dict[str, Any] = None,
|
||||
) -> RunShieldResponse: ...
|
||||
params: dict[str, Any],
|
||||
) -> RunShieldResponse:
|
||||
"""Run a shield.
|
||||
|
||||
:param shield_id: The identifier of the shield to run.
|
||||
:param messages: The messages to run the shield on.
|
||||
:param params: The parameters of the shield.
|
||||
:returns: A RunShieldResponse.
|
||||
"""
|
||||
...
|
||||
|
|
|
@ -61,7 +61,15 @@ class Scoring(Protocol):
|
|||
dataset_id: str,
|
||||
scoring_functions: dict[str, ScoringFnParams | None],
|
||||
save_results_dataset: bool = False,
|
||||
) -> ScoreBatchResponse: ...
|
||||
) -> ScoreBatchResponse:
|
||||
"""Score a batch of rows.
|
||||
|
||||
:param dataset_id: The ID of the dataset to score.
|
||||
:param scoring_functions: The scoring functions to use for the scoring.
|
||||
:param save_results_dataset: Whether to save the results to a dataset.
|
||||
:returns: A ScoreBatchResponse.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/scoring/score", method="POST")
|
||||
async def score(
|
||||
|
@ -73,6 +81,6 @@ class Scoring(Protocol):
|
|||
|
||||
:param input_rows: The rows to score.
|
||||
:param scoring_functions: The scoring functions to use for the scoring.
|
||||
:return: ScoreResponse object containing rows and aggregated results
|
||||
:returns: A ScoreResponse object containing rows and aggregated results.
|
||||
"""
|
||||
...
|
||||
|
|
|
@ -4,6 +4,8 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
# TODO: use enum.StrEnum when we drop support for python 3.10
|
||||
import sys
|
||||
from enum import Enum
|
||||
from typing import (
|
||||
Annotated,
|
||||
|
@ -19,18 +21,27 @@ from llama_stack.apis.common.type_system import ParamType
|
|||
from llama_stack.apis.resource import Resource, ResourceType
|
||||
from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
|
||||
|
||||
if sys.version_info >= (3, 11):
|
||||
from enum import StrEnum
|
||||
else:
|
||||
|
||||
class StrEnum(str, Enum):
|
||||
"""Backport of StrEnum for Python 3.10 and below."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
# Perhaps more structure can be imposed on these functions. Maybe they could be associated
|
||||
# with standard metrics so they can be rolled up?
|
||||
@json_schema_type
|
||||
class ScoringFnParamsType(Enum):
|
||||
class ScoringFnParamsType(StrEnum):
|
||||
llm_as_judge = "llm_as_judge"
|
||||
regex_parser = "regex_parser"
|
||||
basic = "basic"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class AggregationFunctionType(Enum):
|
||||
class AggregationFunctionType(StrEnum):
|
||||
average = "average"
|
||||
weighted_average = "weighted_average"
|
||||
median = "median"
|
||||
|
@ -40,36 +51,36 @@ class AggregationFunctionType(Enum):
|
|||
|
||||
@json_schema_type
|
||||
class LLMAsJudgeScoringFnParams(BaseModel):
|
||||
type: Literal[ScoringFnParamsType.llm_as_judge.value] = ScoringFnParamsType.llm_as_judge.value
|
||||
type: Literal[ScoringFnParamsType.llm_as_judge] = ScoringFnParamsType.llm_as_judge
|
||||
judge_model: str
|
||||
prompt_template: str | None = None
|
||||
judge_score_regexes: list[str] | None = Field(
|
||||
judge_score_regexes: list[str] = Field(
|
||||
description="Regexes to extract the answer from generated response",
|
||||
default_factory=list,
|
||||
default_factory=lambda: [],
|
||||
)
|
||||
aggregation_functions: list[AggregationFunctionType] | None = Field(
|
||||
aggregation_functions: list[AggregationFunctionType] = Field(
|
||||
description="Aggregation functions to apply to the scores of each row",
|
||||
default_factory=list,
|
||||
default_factory=lambda: [],
|
||||
)
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class RegexParserScoringFnParams(BaseModel):
|
||||
type: Literal[ScoringFnParamsType.regex_parser.value] = ScoringFnParamsType.regex_parser.value
|
||||
parsing_regexes: list[str] | None = Field(
|
||||
type: Literal[ScoringFnParamsType.regex_parser] = ScoringFnParamsType.regex_parser
|
||||
parsing_regexes: list[str] = Field(
|
||||
description="Regex to extract the answer from generated response",
|
||||
default_factory=list,
|
||||
default_factory=lambda: [],
|
||||
)
|
||||
aggregation_functions: list[AggregationFunctionType] | None = Field(
|
||||
aggregation_functions: list[AggregationFunctionType] = Field(
|
||||
description="Aggregation functions to apply to the scores of each row",
|
||||
default_factory=list,
|
||||
default_factory=lambda: [],
|
||||
)
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class BasicScoringFnParams(BaseModel):
|
||||
type: Literal[ScoringFnParamsType.basic.value] = ScoringFnParamsType.basic.value
|
||||
aggregation_functions: list[AggregationFunctionType] | None = Field(
|
||||
type: Literal[ScoringFnParamsType.basic] = ScoringFnParamsType.basic
|
||||
aggregation_functions: list[AggregationFunctionType] = Field(
|
||||
description="Aggregation functions to apply to the scores of each row",
|
||||
default_factory=list,
|
||||
)
|
||||
|
@ -99,14 +110,14 @@ class CommonScoringFnFields(BaseModel):
|
|||
|
||||
@json_schema_type
|
||||
class ScoringFn(CommonScoringFnFields, Resource):
|
||||
type: Literal[ResourceType.scoring_function.value] = ResourceType.scoring_function.value
|
||||
type: Literal[ResourceType.scoring_function] = ResourceType.scoring_function
|
||||
|
||||
@property
|
||||
def scoring_fn_id(self) -> str:
|
||||
return self.identifier
|
||||
|
||||
@property
|
||||
def provider_scoring_fn_id(self) -> str:
|
||||
def provider_scoring_fn_id(self) -> str | None:
|
||||
return self.provider_resource_id
|
||||
|
||||
|
||||
|
@ -123,10 +134,21 @@ class ListScoringFunctionsResponse(BaseModel):
|
|||
@runtime_checkable
|
||||
class ScoringFunctions(Protocol):
|
||||
@webmethod(route="/scoring-functions", method="GET")
|
||||
async def list_scoring_functions(self) -> ListScoringFunctionsResponse: ...
|
||||
async def list_scoring_functions(self) -> ListScoringFunctionsResponse:
|
||||
"""List all scoring functions.
|
||||
|
||||
:returns: A ListScoringFunctionsResponse.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/scoring-functions/{scoring_fn_id:path}", method="GET")
|
||||
async def get_scoring_function(self, scoring_fn_id: str, /) -> ScoringFn: ...
|
||||
async def get_scoring_function(self, scoring_fn_id: str, /) -> ScoringFn:
|
||||
"""Get a scoring function by its ID.
|
||||
|
||||
:param scoring_fn_id: The ID of the scoring function to get.
|
||||
:returns: A ScoringFn.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/scoring-functions", method="POST")
|
||||
async def register_scoring_function(
|
||||
|
@ -137,4 +159,14 @@ class ScoringFunctions(Protocol):
|
|||
provider_scoring_fn_id: str | None = None,
|
||||
provider_id: str | None = None,
|
||||
params: ScoringFnParams | None = None,
|
||||
) -> None: ...
|
||||
) -> None:
|
||||
"""Register a scoring function.
|
||||
|
||||
:param scoring_fn_id: The ID of the scoring function to register.
|
||||
:param description: The description of the scoring function.
|
||||
:param return_type: The return type of the scoring function.
|
||||
:param provider_scoring_fn_id: The ID of the provider scoring function to use for the scoring function.
|
||||
:param provider_id: The ID of the provider to use for the scoring function.
|
||||
:param params: The parameters for the scoring function for benchmark eval, these can be overridden for app eval.
|
||||
"""
|
||||
...
|
||||
|
|
|
@ -21,14 +21,14 @@ class CommonShieldFields(BaseModel):
|
|||
class Shield(CommonShieldFields, Resource):
|
||||
"""A safety shield resource that can be used to check content"""
|
||||
|
||||
type: Literal[ResourceType.shield.value] = ResourceType.shield.value
|
||||
type: Literal[ResourceType.shield] = ResourceType.shield
|
||||
|
||||
@property
|
||||
def shield_id(self) -> str:
|
||||
return self.identifier
|
||||
|
||||
@property
|
||||
def provider_shield_id(self) -> str:
|
||||
def provider_shield_id(self) -> str | None:
|
||||
return self.provider_resource_id
|
||||
|
||||
|
||||
|
@ -46,10 +46,21 @@ class ListShieldsResponse(BaseModel):
|
|||
@trace_protocol
|
||||
class Shields(Protocol):
|
||||
@webmethod(route="/shields", method="GET")
|
||||
async def list_shields(self) -> ListShieldsResponse: ...
|
||||
async def list_shields(self) -> ListShieldsResponse:
|
||||
"""List all shields.
|
||||
|
||||
:returns: A ListShieldsResponse.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/shields/{identifier:path}", method="GET")
|
||||
async def get_shield(self, identifier: str) -> Shield: ...
|
||||
async def get_shield(self, identifier: str) -> Shield:
|
||||
"""Get a shield by its identifier.
|
||||
|
||||
:param identifier: The identifier of the shield to get.
|
||||
:returns: A Shield.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/shields", method="POST")
|
||||
async def register_shield(
|
||||
|
@ -58,4 +69,13 @@ class Shields(Protocol):
|
|||
provider_shield_id: str | None = None,
|
||||
provider_id: str | None = None,
|
||||
params: dict[str, Any] | None = None,
|
||||
) -> Shield: ...
|
||||
) -> Shield:
|
||||
"""Register a shield.
|
||||
|
||||
:param shield_id: The identifier of the shield to register.
|
||||
:param provider_shield_id: The identifier of the shield in the provider.
|
||||
:param provider_id: The identifier of the provider.
|
||||
:param params: The parameters of the shield.
|
||||
:returns: A Shield.
|
||||
"""
|
||||
...
|
||||
|
|
|
@ -37,7 +37,7 @@ class Span(BaseModel):
|
|||
name: str
|
||||
start_time: datetime
|
||||
end_time: datetime | None = None
|
||||
attributes: dict[str, Any] | None = Field(default_factory=dict)
|
||||
attributes: dict[str, Any] | None = Field(default_factory=lambda: {})
|
||||
|
||||
def set_attribute(self, key: str, value: Any):
|
||||
if self.attributes is None:
|
||||
|
@ -74,19 +74,19 @@ class EventCommon(BaseModel):
|
|||
trace_id: str
|
||||
span_id: str
|
||||
timestamp: datetime
|
||||
attributes: dict[str, Primitive] | None = Field(default_factory=dict)
|
||||
attributes: dict[str, Primitive] | None = Field(default_factory=lambda: {})
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class UnstructuredLogEvent(EventCommon):
|
||||
type: Literal[EventType.UNSTRUCTURED_LOG.value] = EventType.UNSTRUCTURED_LOG.value
|
||||
type: Literal[EventType.UNSTRUCTURED_LOG] = EventType.UNSTRUCTURED_LOG
|
||||
message: str
|
||||
severity: LogSeverity
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class MetricEvent(EventCommon):
|
||||
type: Literal[EventType.METRIC.value] = EventType.METRIC.value
|
||||
type: Literal[EventType.METRIC] = EventType.METRIC
|
||||
metric: str # this would be an enum
|
||||
value: int | float
|
||||
unit: str
|
||||
|
@ -131,14 +131,14 @@ class StructuredLogType(Enum):
|
|||
|
||||
@json_schema_type
|
||||
class SpanStartPayload(BaseModel):
|
||||
type: Literal[StructuredLogType.SPAN_START.value] = StructuredLogType.SPAN_START.value
|
||||
type: Literal[StructuredLogType.SPAN_START] = StructuredLogType.SPAN_START
|
||||
name: str
|
||||
parent_span_id: str | None = None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class SpanEndPayload(BaseModel):
|
||||
type: Literal[StructuredLogType.SPAN_END.value] = StructuredLogType.SPAN_END.value
|
||||
type: Literal[StructuredLogType.SPAN_END] = StructuredLogType.SPAN_END
|
||||
status: SpanStatus
|
||||
|
||||
|
||||
|
@ -151,7 +151,7 @@ register_schema(StructuredLogPayload, name="StructuredLogPayload")
|
|||
|
||||
@json_schema_type
|
||||
class StructuredLogEvent(EventCommon):
|
||||
type: Literal[EventType.STRUCTURED_LOG.value] = EventType.STRUCTURED_LOG.value
|
||||
type: Literal[EventType.STRUCTURED_LOG] = EventType.STRUCTURED_LOG
|
||||
payload: StructuredLogPayload
|
||||
|
||||
|
||||
|
@ -203,10 +203,61 @@ class QuerySpanTreeResponse(BaseModel):
|
|||
data: dict[str, SpanWithStatus]
|
||||
|
||||
|
||||
class MetricQueryType(Enum):
|
||||
RANGE = "range"
|
||||
INSTANT = "instant"
|
||||
|
||||
|
||||
class MetricLabelOperator(Enum):
|
||||
EQUALS = "="
|
||||
NOT_EQUALS = "!="
|
||||
REGEX_MATCH = "=~"
|
||||
REGEX_NOT_MATCH = "!~"
|
||||
|
||||
|
||||
class MetricLabelMatcher(BaseModel):
|
||||
name: str
|
||||
value: str
|
||||
operator: MetricLabelOperator = MetricLabelOperator.EQUALS
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class MetricLabel(BaseModel):
|
||||
name: str
|
||||
value: str
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class MetricDataPoint(BaseModel):
|
||||
timestamp: int
|
||||
value: float
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class MetricSeries(BaseModel):
|
||||
metric: str
|
||||
labels: list[MetricLabel]
|
||||
values: list[MetricDataPoint]
|
||||
|
||||
|
||||
class QueryMetricsResponse(BaseModel):
|
||||
data: list[MetricSeries]
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class Telemetry(Protocol):
|
||||
@webmethod(route="/telemetry/events", method="POST")
|
||||
async def log_event(self, event: Event, ttl_seconds: int = DEFAULT_TTL_DAYS * 86400) -> None: ...
|
||||
async def log_event(
|
||||
self,
|
||||
event: Event,
|
||||
ttl_seconds: int = DEFAULT_TTL_DAYS * 86400,
|
||||
) -> None:
|
||||
"""Log an event.
|
||||
|
||||
:param event: The event to log.
|
||||
:param ttl_seconds: The time to live of the event.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/telemetry/traces", method="POST")
|
||||
async def query_traces(
|
||||
|
@ -215,13 +266,35 @@ class Telemetry(Protocol):
|
|||
limit: int | None = 100,
|
||||
offset: int | None = 0,
|
||||
order_by: list[str] | None = None,
|
||||
) -> QueryTracesResponse: ...
|
||||
) -> QueryTracesResponse:
|
||||
"""Query traces.
|
||||
|
||||
:param attribute_filters: The attribute filters to apply to the traces.
|
||||
:param limit: The limit of traces to return.
|
||||
:param offset: The offset of the traces to return.
|
||||
:param order_by: The order by of the traces to return.
|
||||
:returns: A QueryTracesResponse.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/telemetry/traces/{trace_id:path}", method="GET")
|
||||
async def get_trace(self, trace_id: str) -> Trace: ...
|
||||
async def get_trace(self, trace_id: str) -> Trace:
|
||||
"""Get a trace by its ID.
|
||||
|
||||
:param trace_id: The ID of the trace to get.
|
||||
:returns: A Trace.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/telemetry/traces/{trace_id:path}/spans/{span_id:path}", method="GET")
|
||||
async def get_span(self, trace_id: str, span_id: str) -> Span: ...
|
||||
async def get_span(self, trace_id: str, span_id: str) -> Span:
|
||||
"""Get a span by its ID.
|
||||
|
||||
:param trace_id: The ID of the trace to get the span from.
|
||||
:param span_id: The ID of the span to get.
|
||||
:returns: A Span.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/telemetry/spans/{span_id:path}/tree", method="POST")
|
||||
async def get_span_tree(
|
||||
|
@ -229,7 +302,15 @@ class Telemetry(Protocol):
|
|||
span_id: str,
|
||||
attributes_to_return: list[str] | None = None,
|
||||
max_depth: int | None = None,
|
||||
) -> QuerySpanTreeResponse: ...
|
||||
) -> QuerySpanTreeResponse:
|
||||
"""Get a span tree by its ID.
|
||||
|
||||
:param span_id: The ID of the span to get the tree from.
|
||||
:param attributes_to_return: The attributes to return in the tree.
|
||||
:param max_depth: The maximum depth of the tree.
|
||||
:returns: A QuerySpanTreeResponse.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/telemetry/spans", method="POST")
|
||||
async def query_spans(
|
||||
|
@ -237,7 +318,15 @@ class Telemetry(Protocol):
|
|||
attribute_filters: list[QueryCondition],
|
||||
attributes_to_return: list[str],
|
||||
max_depth: int | None = None,
|
||||
) -> QuerySpansResponse: ...
|
||||
) -> QuerySpansResponse:
|
||||
"""Query spans.
|
||||
|
||||
:param attribute_filters: The attribute filters to apply to the spans.
|
||||
:param attributes_to_return: The attributes to return in the spans.
|
||||
:param max_depth: The maximum depth of the tree.
|
||||
:returns: A QuerySpansResponse.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/telemetry/spans/export", method="POST")
|
||||
async def save_spans_to_dataset(
|
||||
|
@ -246,4 +335,34 @@ class Telemetry(Protocol):
|
|||
attributes_to_save: list[str],
|
||||
dataset_id: str,
|
||||
max_depth: int | None = None,
|
||||
) -> None: ...
|
||||
) -> None:
|
||||
"""Save spans to a dataset.
|
||||
|
||||
:param attribute_filters: The attribute filters to apply to the spans.
|
||||
:param attributes_to_save: The attributes to save to the dataset.
|
||||
:param dataset_id: The ID of the dataset to save the spans to.
|
||||
:param max_depth: The maximum depth of the tree.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/telemetry/metrics/{metric_name}", method="POST")
|
||||
async def query_metrics(
|
||||
self,
|
||||
metric_name: str,
|
||||
start_time: int,
|
||||
end_time: int | None = None,
|
||||
granularity: str | None = "1d",
|
||||
query_type: MetricQueryType = MetricQueryType.RANGE,
|
||||
label_matchers: list[MetricLabelMatcher] | None = None,
|
||||
) -> QueryMetricsResponse:
|
||||
"""Query metrics.
|
||||
|
||||
:param metric_name: The name of the metric to query.
|
||||
:param start_time: The start time of the metric to query.
|
||||
:param end_time: The end time of the metric to query.
|
||||
:param granularity: The granularity of the metric to query.
|
||||
:param query_type: The type of query to perform.
|
||||
:param label_matchers: The label matchers to apply to the metric.
|
||||
:returns: A QueryMetricsResponse.
|
||||
"""
|
||||
...
|
||||
|
|
|
@ -7,7 +7,7 @@
|
|||
from enum import Enum
|
||||
from typing import Annotated, Any, Literal
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from typing_extensions import Protocol, runtime_checkable
|
||||
|
||||
from llama_stack.apis.common.content_types import URL, InterleavedContent
|
||||
|
@ -67,11 +67,33 @@ register_schema(RAGQueryGeneratorConfig, name="RAGQueryGeneratorConfig")
|
|||
|
||||
@json_schema_type
|
||||
class RAGQueryConfig(BaseModel):
|
||||
"""
|
||||
Configuration for the RAG query generation.
|
||||
|
||||
:param query_generator_config: Configuration for the query generator.
|
||||
:param max_tokens_in_context: Maximum number of tokens in the context.
|
||||
:param max_chunks: Maximum number of chunks to retrieve.
|
||||
:param chunk_template: Template for formatting each retrieved chunk in the context.
|
||||
Available placeholders: {index} (1-based chunk ordinal), {chunk.content} (chunk content string), {metadata} (chunk metadata dict).
|
||||
Default: "Result {index}\\nContent: {chunk.content}\\nMetadata: {metadata}\\n"
|
||||
"""
|
||||
|
||||
# This config defines how a query is generated using the messages
|
||||
# for memory bank retrieval.
|
||||
query_generator_config: RAGQueryGeneratorConfig = Field(default=DefaultRAGQueryGeneratorConfig())
|
||||
max_tokens_in_context: int = 4096
|
||||
max_chunks: int = 5
|
||||
chunk_template: str = "Result {index}\nContent: {chunk.content}\nMetadata: {metadata}\n"
|
||||
|
||||
@field_validator("chunk_template")
|
||||
def validate_chunk_template(cls, v: str) -> str:
|
||||
if "{chunk.content}" not in v:
|
||||
raise ValueError("chunk_template must contain {chunk.content}")
|
||||
if "{index}" not in v:
|
||||
raise ValueError("chunk_template must contain {index}")
|
||||
if len(v) == 0:
|
||||
raise ValueError("chunk_template must not be empty")
|
||||
return v
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
|
|
|
@ -36,7 +36,7 @@ class ToolHost(Enum):
|
|||
|
||||
@json_schema_type
|
||||
class Tool(Resource):
|
||||
type: Literal[ResourceType.tool.value] = ResourceType.tool.value
|
||||
type: Literal[ResourceType.tool] = ResourceType.tool
|
||||
toolgroup_id: str
|
||||
tool_host: ToolHost
|
||||
description: str
|
||||
|
@ -62,7 +62,7 @@ class ToolGroupInput(BaseModel):
|
|||
|
||||
@json_schema_type
|
||||
class ToolGroup(Resource):
|
||||
type: Literal[ResourceType.tool_group.value] = ResourceType.tool_group.value
|
||||
type: Literal[ResourceType.tool_group] = ResourceType.tool_group
|
||||
mcp_endpoint: URL | None = None
|
||||
args: dict[str, Any] | None = None
|
||||
|
||||
|
@ -103,37 +103,65 @@ class ToolGroups(Protocol):
|
|||
mcp_endpoint: URL | None = None,
|
||||
args: dict[str, Any] | None = None,
|
||||
) -> None:
|
||||
"""Register a tool group"""
|
||||
"""Register a tool group.
|
||||
|
||||
:param toolgroup_id: The ID of the tool group to register.
|
||||
:param provider_id: The ID of the provider to use for the tool group.
|
||||
:param mcp_endpoint: The MCP endpoint to use for the tool group.
|
||||
:param args: A dictionary of arguments to pass to the tool group.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/toolgroups/{toolgroup_id:path}", method="GET")
|
||||
async def get_tool_group(
|
||||
self,
|
||||
toolgroup_id: str,
|
||||
) -> ToolGroup: ...
|
||||
) -> ToolGroup:
|
||||
"""Get a tool group by its ID.
|
||||
|
||||
:param toolgroup_id: The ID of the tool group to get.
|
||||
:returns: A ToolGroup.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/toolgroups", method="GET")
|
||||
async def list_tool_groups(self) -> ListToolGroupsResponse:
|
||||
"""List tool groups with optional provider"""
|
||||
"""List tool groups with optional provider.
|
||||
|
||||
:returns: A ListToolGroupsResponse.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/tools", method="GET")
|
||||
async def list_tools(self, toolgroup_id: str | None = None) -> ListToolsResponse:
|
||||
"""List tools with optional tool group"""
|
||||
"""List tools with optional tool group.
|
||||
|
||||
:param toolgroup_id: The ID of the tool group to list tools for.
|
||||
:returns: A ListToolsResponse.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/tools/{tool_name:path}", method="GET")
|
||||
async def get_tool(
|
||||
self,
|
||||
tool_name: str,
|
||||
) -> Tool: ...
|
||||
) -> Tool:
|
||||
"""Get a tool by its name.
|
||||
|
||||
:param tool_name: The name of the tool to get.
|
||||
:returns: A Tool.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/toolgroups/{toolgroup_id:path}", method="DELETE")
|
||||
async def unregister_toolgroup(
|
||||
self,
|
||||
toolgroup_id: str,
|
||||
) -> None:
|
||||
"""Unregister a tool group"""
|
||||
"""Unregister a tool group.
|
||||
|
||||
:param toolgroup_id: The ID of the tool group to unregister.
|
||||
"""
|
||||
...
|
||||
|
||||
|
||||
|
@ -152,9 +180,21 @@ class ToolRuntime(Protocol):
|
|||
@webmethod(route="/tool-runtime/list-tools", method="GET")
|
||||
async def list_runtime_tools(
|
||||
self, tool_group_id: str | None = None, mcp_endpoint: URL | None = None
|
||||
) -> ListToolDefsResponse: ...
|
||||
) -> ListToolDefsResponse:
|
||||
"""List all tools in the runtime.
|
||||
|
||||
:param tool_group_id: The ID of the tool group to list tools for.
|
||||
:param mcp_endpoint: The MCP endpoint to use for the tool group.
|
||||
:returns: A ListToolDefsResponse.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/tool-runtime/invoke", method="POST")
|
||||
async def invoke_tool(self, tool_name: str, kwargs: dict[str, Any]) -> ToolInvocationResult:
|
||||
"""Run a tool with the given arguments"""
|
||||
"""Run a tool with the given arguments.
|
||||
|
||||
:param tool_name: The name of the tool to invoke.
|
||||
:param kwargs: A dictionary of arguments to pass to the tool.
|
||||
:returns: A ToolInvocationResult.
|
||||
"""
|
||||
...
|
||||
|
|
|
@ -15,7 +15,7 @@ from llama_stack.schema_utils import json_schema_type, webmethod
|
|||
|
||||
@json_schema_type
|
||||
class VectorDB(Resource):
|
||||
type: Literal[ResourceType.vector_db.value] = ResourceType.vector_db.value
|
||||
type: Literal[ResourceType.vector_db] = ResourceType.vector_db
|
||||
|
||||
embedding_model: str
|
||||
embedding_dimension: int
|
||||
|
@ -25,7 +25,7 @@ class VectorDB(Resource):
|
|||
return self.identifier
|
||||
|
||||
@property
|
||||
def provider_vector_db_id(self) -> str:
|
||||
def provider_vector_db_id(self) -> str | None:
|
||||
return self.provider_resource_id
|
||||
|
||||
|
||||
|
@ -44,13 +44,24 @@ class ListVectorDBsResponse(BaseModel):
|
|||
@trace_protocol
|
||||
class VectorDBs(Protocol):
|
||||
@webmethod(route="/vector-dbs", method="GET")
|
||||
async def list_vector_dbs(self) -> ListVectorDBsResponse: ...
|
||||
async def list_vector_dbs(self) -> ListVectorDBsResponse:
|
||||
"""List all vector databases.
|
||||
|
||||
:returns: A ListVectorDBsResponse.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/vector-dbs/{vector_db_id:path}", method="GET")
|
||||
async def get_vector_db(
|
||||
self,
|
||||
vector_db_id: str,
|
||||
) -> VectorDB: ...
|
||||
) -> VectorDB:
|
||||
"""Get a vector database by its identifier.
|
||||
|
||||
:param vector_db_id: The identifier of the vector database to get.
|
||||
:returns: A VectorDB.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/vector-dbs", method="POST")
|
||||
async def register_vector_db(
|
||||
|
@ -60,7 +71,22 @@ class VectorDBs(Protocol):
|
|||
embedding_dimension: int | None = 384,
|
||||
provider_id: str | None = None,
|
||||
provider_vector_db_id: str | None = None,
|
||||
) -> VectorDB: ...
|
||||
) -> VectorDB:
|
||||
"""Register a vector database.
|
||||
|
||||
:param vector_db_id: The identifier of the vector database to register.
|
||||
:param embedding_model: The embedding model to use.
|
||||
:param embedding_dimension: The dimension of the embedding model.
|
||||
:param provider_id: The identifier of the provider.
|
||||
:param provider_vector_db_id: The identifier of the vector database in the provider.
|
||||
:returns: A VectorDB.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/vector-dbs/{vector_db_id:path}", method="DELETE")
|
||||
async def unregister_vector_db(self, vector_db_id: str) -> None: ...
|
||||
async def unregister_vector_db(self, vector_db_id: str) -> None:
|
||||
"""Unregister a vector database.
|
||||
|
||||
:param vector_db_id: The identifier of the vector database to unregister.
|
||||
"""
|
||||
...
|
||||
|
|
|
@ -46,7 +46,14 @@ class VectorIO(Protocol):
|
|||
vector_db_id: str,
|
||||
chunks: list[Chunk],
|
||||
ttl_seconds: int | None = None,
|
||||
) -> None: ...
|
||||
) -> None:
|
||||
"""Insert chunks into a vector database.
|
||||
|
||||
:param vector_db_id: The identifier of the vector database to insert the chunks into.
|
||||
:param chunks: The chunks to insert.
|
||||
:param ttl_seconds: The time to live of the chunks.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/vector-io/query", method="POST")
|
||||
async def query_chunks(
|
||||
|
@ -54,4 +61,12 @@ class VectorIO(Protocol):
|
|||
vector_db_id: str,
|
||||
query: InterleavedContent,
|
||||
params: dict[str, Any] | None = None,
|
||||
) -> QueryChunksResponse: ...
|
||||
) -> QueryChunksResponse:
|
||||
"""Query chunks from a vector database.
|
||||
|
||||
:param vector_db_id: The identifier of the vector database to query.
|
||||
:param query: The query to search for.
|
||||
:param params: The parameters of the query.
|
||||
:returns: A QueryChunksResponse.
|
||||
"""
|
||||
...
|
||||
|
|
|
@ -38,7 +38,10 @@ class LlamaCLIParser:
|
|||
print_subcommand_description(self.parser, subparsers)
|
||||
|
||||
def parse_args(self) -> argparse.Namespace:
|
||||
return self.parser.parse_args()
|
||||
args = self.parser.parse_args()
|
||||
if not isinstance(args, argparse.Namespace):
|
||||
raise TypeError(f"Expected argparse.Namespace, got {type(args)}")
|
||||
return args
|
||||
|
||||
def run(self, args: argparse.Namespace) -> None:
|
||||
args.func(args)
|
||||
|
|
|
@ -12,6 +12,7 @@ import shutil
|
|||
import sys
|
||||
import textwrap
|
||||
from functools import lru_cache
|
||||
from importlib.abc import Traversable
|
||||
from pathlib import Path
|
||||
|
||||
import yaml
|
||||
|
@ -36,7 +37,8 @@ from llama_stack.distribution.datatypes import (
|
|||
)
|
||||
from llama_stack.distribution.distribution import get_provider_registry
|
||||
from llama_stack.distribution.resolver import InvalidProviderError
|
||||
from llama_stack.distribution.utils.config_dirs import DISTRIBS_BASE_DIR
|
||||
from llama_stack.distribution.stack import replace_env_vars
|
||||
from llama_stack.distribution.utils.config_dirs import DISTRIBS_BASE_DIR, EXTERNAL_PROVIDERS_DIR
|
||||
from llama_stack.distribution.utils.dynamic import instantiate_class_type
|
||||
from llama_stack.distribution.utils.exec import formulate_run_args, run_command
|
||||
from llama_stack.distribution.utils.image_types import LlamaStackImageType
|
||||
|
@ -202,7 +204,11 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
|
|||
else:
|
||||
with open(args.config) as f:
|
||||
try:
|
||||
build_config = BuildConfig(**yaml.safe_load(f))
|
||||
contents = yaml.safe_load(f)
|
||||
contents = replace_env_vars(contents)
|
||||
build_config = BuildConfig(**contents)
|
||||
if args.image_type:
|
||||
build_config.image_type = args.image_type
|
||||
except Exception as e:
|
||||
cprint(
|
||||
f"Could not parse config file {args.config}: {e}",
|
||||
|
@ -245,11 +251,12 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
|
|||
sys.exit(1)
|
||||
|
||||
if args.run:
|
||||
run_config = Path(run_config)
|
||||
config_dict = yaml.safe_load(run_config.read_text())
|
||||
config = parse_and_maybe_upgrade_config(config_dict)
|
||||
if not os.path.exists(config.external_providers_dir):
|
||||
os.makedirs(config.external_providers_dir, exist_ok=True)
|
||||
run_args = formulate_run_args(args.image_type, args.image_name, config, args.template)
|
||||
run_args.extend([run_config, str(os.getenv("LLAMA_STACK_PORT", 8321))])
|
||||
run_args.extend([str(os.getenv("LLAMA_STACK_PORT", 8321)), "--config", run_config])
|
||||
run_command(run_args)
|
||||
|
||||
|
||||
|
@ -257,7 +264,7 @@ def _generate_run_config(
|
|||
build_config: BuildConfig,
|
||||
build_dir: Path,
|
||||
image_name: str,
|
||||
) -> str:
|
||||
) -> Path:
|
||||
"""
|
||||
Generate a run.yaml template file for user to edit from a build.yaml file
|
||||
"""
|
||||
|
@ -267,7 +274,9 @@ def _generate_run_config(
|
|||
image_name=image_name,
|
||||
apis=apis,
|
||||
providers={},
|
||||
external_providers_dir=build_config.external_providers_dir if build_config.external_providers_dir else None,
|
||||
external_providers_dir=build_config.external_providers_dir
|
||||
if build_config.external_providers_dir
|
||||
else EXTERNAL_PROVIDERS_DIR,
|
||||
)
|
||||
# build providers dict
|
||||
provider_registry = get_provider_registry(build_config)
|
||||
|
@ -334,7 +343,7 @@ def _run_stack_build_command_from_build_config(
|
|||
image_name: str | None = None,
|
||||
template_name: str | None = None,
|
||||
config_path: str | None = None,
|
||||
) -> str:
|
||||
) -> Path | Traversable:
|
||||
image_name = image_name or build_config.image_name
|
||||
if build_config.image_type == LlamaStackImageType.CONTAINER.value:
|
||||
if template_name:
|
||||
|
|
|
@ -49,7 +49,7 @@ class StackBuild(Subcommand):
|
|||
type=str,
|
||||
help="Image Type to use for the build. If not specified, will use the image type from the template config.",
|
||||
choices=[e.value for e in ImageType],
|
||||
default=ImageType.CONDA.value,
|
||||
default=None, # no default so we can detect if a user specified --image-type and override image_type in the config
|
||||
)
|
||||
|
||||
self.parser.add_argument(
|
||||
|
|
|
@ -46,7 +46,7 @@ class StackListProviders(Subcommand):
|
|||
else:
|
||||
providers = [(k.value, prov) for k, prov in all_providers.items()]
|
||||
|
||||
providers = [p for api, p in providers if api in self.providable_apis]
|
||||
providers = [(api, p) for api, p in providers if api in self.providable_apis]
|
||||
|
||||
# eventually, this should query a registry at llama.meta.com/llamastack/distributions
|
||||
headers = [
|
||||
|
@ -57,7 +57,7 @@ class StackListProviders(Subcommand):
|
|||
|
||||
rows = []
|
||||
|
||||
specs = [spec for p in providers for spec in p.values()]
|
||||
specs = [spec for api, p in providers for spec in p.values()]
|
||||
for spec in specs:
|
||||
if spec.is_sample:
|
||||
continue
|
||||
|
@ -65,7 +65,7 @@ class StackListProviders(Subcommand):
|
|||
[
|
||||
spec.api.value,
|
||||
spec.provider_type,
|
||||
",".join(spec.pip_packages),
|
||||
",".join(spec.pip_packages) if hasattr(spec, "pip_packages") else "",
|
||||
]
|
||||
)
|
||||
print_table(
|
||||
|
|
|
@ -33,7 +33,8 @@ class StackRun(Subcommand):
|
|||
self.parser.add_argument(
|
||||
"config",
|
||||
type=str,
|
||||
help="Path to config file to use for the run",
|
||||
nargs="?", # Make it optional
|
||||
help="Path to config file to use for the run. Required for venv and conda environments.",
|
||||
)
|
||||
self.parser.add_argument(
|
||||
"--port",
|
||||
|
@ -47,28 +48,12 @@ class StackRun(Subcommand):
|
|||
default=os.environ.get("CONDA_DEFAULT_ENV"),
|
||||
help="Name of the image to run. Defaults to the current environment",
|
||||
)
|
||||
self.parser.add_argument(
|
||||
"--disable-ipv6",
|
||||
action="store_true",
|
||||
help="Disable IPv6 support",
|
||||
default=False,
|
||||
)
|
||||
self.parser.add_argument(
|
||||
"--env",
|
||||
action="append",
|
||||
help="Environment variables to pass to the server in KEY=VALUE format. Can be specified multiple times.",
|
||||
metavar="KEY=VALUE",
|
||||
)
|
||||
self.parser.add_argument(
|
||||
"--tls-keyfile",
|
||||
type=str,
|
||||
help="Path to TLS key file for HTTPS",
|
||||
)
|
||||
self.parser.add_argument(
|
||||
"--tls-certfile",
|
||||
type=str,
|
||||
help="Path to TLS certificate file for HTTPS",
|
||||
)
|
||||
self.parser.add_argument(
|
||||
"--image-type",
|
||||
type=str,
|
||||
|
@ -98,6 +83,13 @@ class StackRun(Subcommand):
|
|||
from llama_stack.distribution.utils.config_dirs import DISTRIBS_BASE_DIR
|
||||
from llama_stack.distribution.utils.exec import formulate_run_args, run_command
|
||||
|
||||
image_type, image_name = self._get_image_type_and_name(args)
|
||||
|
||||
# Check if config is required based on image type
|
||||
if (image_type in [ImageType.CONDA.value, ImageType.VENV.value]) and not args.config:
|
||||
self.parser.error("Config file is required for venv and conda environments")
|
||||
|
||||
if args.config:
|
||||
config_file = Path(args.config)
|
||||
has_yaml_suffix = args.config.endswith(".yaml")
|
||||
template_name = None
|
||||
|
@ -131,10 +123,14 @@ class StackRun(Subcommand):
|
|||
|
||||
try:
|
||||
config = parse_and_maybe_upgrade_config(config_dict)
|
||||
if not os.path.exists(str(config.external_providers_dir)):
|
||||
os.makedirs(str(config.external_providers_dir), exist_ok=True)
|
||||
except AttributeError as e:
|
||||
self.parser.error(f"failed to parse config file '{config_file}':\n {e}")
|
||||
|
||||
image_type, image_name = self._get_image_type_and_name(args)
|
||||
else:
|
||||
config = None
|
||||
config_file = None
|
||||
template_name = None
|
||||
|
||||
# If neither image type nor image name is provided, assume the server should be run directly
|
||||
# using the current environment packages.
|
||||
|
@ -157,9 +153,10 @@ class StackRun(Subcommand):
|
|||
else:
|
||||
run_args = formulate_run_args(image_type, image_name, config, template_name)
|
||||
|
||||
run_args.extend([str(config_file), str(args.port)])
|
||||
if args.disable_ipv6:
|
||||
run_args.append("--disable-ipv6")
|
||||
run_args.extend([str(args.port)])
|
||||
|
||||
if config_file:
|
||||
run_args.extend(["--config", str(config_file)])
|
||||
|
||||
if args.env:
|
||||
for env_var in args.env:
|
||||
|
@ -172,6 +169,4 @@ class StackRun(Subcommand):
|
|||
return
|
||||
run_args.extend(["--env", f"{key}={value}"])
|
||||
|
||||
if args.tls_keyfile and args.tls_certfile:
|
||||
run_args.extend(["--tls-keyfile", args.tls_keyfile, "--tls-certfile", args.tls_certfile])
|
||||
run_command(run_args)
|
||||
|
|
|
@ -154,6 +154,12 @@ get_python_cmd() {
|
|||
fi
|
||||
}
|
||||
|
||||
# Add other required item commands generic to all containers
|
||||
add_to_container << EOF
|
||||
# Allows running as non-root user
|
||||
RUN mkdir -p /.llama/providers.d /.cache
|
||||
EOF
|
||||
|
||||
if [ -n "$run_config" ]; then
|
||||
# Copy the run config to the build context since it's an absolute path
|
||||
cp "$run_config" "$BUILD_CONTEXT_DIR/run.yaml"
|
||||
|
@ -166,17 +172,19 @@ EOF
|
|||
# and update the configuration to reference the new container path
|
||||
python_cmd=$(get_python_cmd)
|
||||
external_providers_dir=$($python_cmd -c "import yaml; config = yaml.safe_load(open('$run_config')); print(config.get('external_providers_dir') or '')")
|
||||
if [ -n "$external_providers_dir" ]; then
|
||||
external_providers_dir=$(eval echo "$external_providers_dir")
|
||||
if [ -n "$external_providers_dir" ] && [ -d "$external_providers_dir" ]; then
|
||||
echo "Copying external providers directory: $external_providers_dir"
|
||||
cp -r "$external_providers_dir" "$BUILD_CONTEXT_DIR/providers.d"
|
||||
add_to_container << EOF
|
||||
COPY $external_providers_dir /app/providers.d
|
||||
COPY providers.d /.llama/providers.d
|
||||
EOF
|
||||
# Edit the run.yaml file to change the external_providers_dir to /app/providers.d
|
||||
# Edit the run.yaml file to change the external_providers_dir to /.llama/providers.d
|
||||
if [ "$(uname)" = "Darwin" ]; then
|
||||
sed -i.bak -e 's|external_providers_dir:.*|external_providers_dir: /app/providers.d|' "$BUILD_CONTEXT_DIR/run.yaml"
|
||||
sed -i.bak -e 's|external_providers_dir:.*|external_providers_dir: /.llama/providers.d|' "$BUILD_CONTEXT_DIR/run.yaml"
|
||||
rm -f "$BUILD_CONTEXT_DIR/run.yaml.bak"
|
||||
else
|
||||
sed -i 's|external_providers_dir:.*|external_providers_dir: /app/providers.d|' "$BUILD_CONTEXT_DIR/run.yaml"
|
||||
sed -i 's|external_providers_dir:.*|external_providers_dir: /.llama/providers.d|' "$BUILD_CONTEXT_DIR/run.yaml"
|
||||
fi
|
||||
fi
|
||||
fi
|
||||
|
@ -255,9 +263,6 @@ fi
|
|||
# Add other require item commands genearic to all containers
|
||||
add_to_container << EOF
|
||||
|
||||
# Allows running as non-root user
|
||||
RUN mkdir -p /.llama /.cache
|
||||
|
||||
RUN chmod -R g+rw /app /.llama /.cache
|
||||
EOF
|
||||
|
||||
|
|
|
@ -17,6 +17,7 @@ from llama_stack.distribution.distribution import (
|
|||
builtin_automatically_routed_apis,
|
||||
get_provider_registry,
|
||||
)
|
||||
from llama_stack.distribution.utils.config_dirs import EXTERNAL_PROVIDERS_DIR
|
||||
from llama_stack.distribution.utils.dynamic import instantiate_class_type
|
||||
from llama_stack.distribution.utils.prompt_for_config import prompt_for_config
|
||||
from llama_stack.providers.datatypes import Api, ProviderSpec
|
||||
|
@ -73,11 +74,7 @@ def configure_api_providers(config: StackRunConfig, build_spec: DistributionSpec
|
|||
|
||||
existing_providers = config.providers.get(api_str, [])
|
||||
if existing_providers:
|
||||
logger.info(
|
||||
f"Re-configuring existing providers for API `{api_str}`...",
|
||||
"green",
|
||||
attrs=["bold"],
|
||||
)
|
||||
logger.info(f"Re-configuring existing providers for API `{api_str}`...")
|
||||
updated_providers = []
|
||||
for p in existing_providers:
|
||||
logger.info(f"> Configuring provider `({p.provider_type})`")
|
||||
|
@ -91,7 +88,7 @@ def configure_api_providers(config: StackRunConfig, build_spec: DistributionSpec
|
|||
if not plist:
|
||||
raise ValueError(f"No provider configured for API {api_str}?")
|
||||
|
||||
logger.info(f"Configuring API `{api_str}`...", "green", attrs=["bold"])
|
||||
logger.info(f"Configuring API `{api_str}`...")
|
||||
updated_providers = []
|
||||
for i, provider_type in enumerate(plist):
|
||||
if i >= 1:
|
||||
|
@ -174,4 +171,7 @@ def parse_and_maybe_upgrade_config(config_dict: dict[str, Any]) -> StackRunConfi
|
|||
|
||||
config_dict["version"] = LLAMA_STACK_RUN_CONFIG_VERSION
|
||||
|
||||
if not config_dict.get("external_providers_dir", None):
|
||||
config_dict["external_providers_dir"] = EXTERNAL_PROVIDERS_DIR
|
||||
|
||||
return StackRunConfig(**config_dict)
|
||||
|
|
|
@ -5,9 +5,10 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import Annotated, Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
from llama_stack.apis.benchmarks import Benchmark, BenchmarkInput
|
||||
from llama_stack.apis.datasetio import DatasetIO
|
||||
|
@ -249,10 +250,18 @@ class ServerConfig(BaseModel):
|
|||
default=None,
|
||||
description="Path to TLS key file for HTTPS",
|
||||
)
|
||||
tls_cafile: str | None = Field(
|
||||
default=None,
|
||||
description="Path to TLS CA file for HTTPS with mutual TLS authentication",
|
||||
)
|
||||
auth: AuthenticationConfig | None = Field(
|
||||
default=None,
|
||||
description="Authentication configuration for the server",
|
||||
)
|
||||
host: str | None = Field(
|
||||
default=None,
|
||||
description="The host the server should listen on",
|
||||
)
|
||||
|
||||
|
||||
class StackRunConfig(BaseModel):
|
||||
|
@ -304,11 +313,20 @@ a default SQLite store will be used.""",
|
|||
description="Configuration for the HTTP(S) server",
|
||||
)
|
||||
|
||||
external_providers_dir: str | None = Field(
|
||||
external_providers_dir: Path | None = Field(
|
||||
default=None,
|
||||
description="Path to directory containing external provider implementations. The providers code and dependencies must be installed on the system.",
|
||||
)
|
||||
|
||||
@field_validator("external_providers_dir")
|
||||
@classmethod
|
||||
def validate_external_providers_dir(cls, v):
|
||||
if v is None:
|
||||
return None
|
||||
if isinstance(v, str):
|
||||
return Path(v)
|
||||
return v
|
||||
|
||||
|
||||
class BuildConfig(BaseModel):
|
||||
version: str = LLAMA_STACK_BUILD_CONFIG_VERSION
|
||||
|
@ -322,8 +340,17 @@ class BuildConfig(BaseModel):
|
|||
default=None,
|
||||
description="Name of the distribution to build",
|
||||
)
|
||||
external_providers_dir: str | None = Field(
|
||||
external_providers_dir: Path | None = Field(
|
||||
default=None,
|
||||
description="Path to directory containing external provider implementations. The providers packages will be resolved from this directory. "
|
||||
"pip_packages MUST contain the provider package name.",
|
||||
)
|
||||
|
||||
@field_validator("external_providers_dir")
|
||||
@classmethod
|
||||
def validate_external_providers_dir(cls, v):
|
||||
if v is None:
|
||||
return None
|
||||
if isinstance(v, str):
|
||||
return Path(v)
|
||||
return v
|
||||
|
|
|
@ -145,7 +145,7 @@ def get_provider_registry(
|
|||
|
||||
# Check if config has the external_providers_dir attribute
|
||||
if config and hasattr(config, "external_providers_dir") and config.external_providers_dir:
|
||||
external_providers_dir = os.path.abspath(config.external_providers_dir)
|
||||
external_providers_dir = os.path.abspath(os.path.expanduser(config.external_providers_dir))
|
||||
if not os.path.exists(external_providers_dir):
|
||||
raise FileNotFoundError(f"External providers directory not found: {external_providers_dir}")
|
||||
logger.info(f"Loading external providers from {external_providers_dir}")
|
||||
|
|
|
@ -30,7 +30,7 @@ from termcolor import cprint
|
|||
|
||||
from llama_stack.distribution.build import print_pip_install_help
|
||||
from llama_stack.distribution.configure import parse_and_maybe_upgrade_config
|
||||
from llama_stack.distribution.datatypes import Api
|
||||
from llama_stack.distribution.datatypes import Api, BuildConfig, DistributionSpec
|
||||
from llama_stack.distribution.request_headers import (
|
||||
PROVIDER_DATA_VAR,
|
||||
request_provider_data_context,
|
||||
|
@ -216,7 +216,19 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
|
|||
"yellow",
|
||||
)
|
||||
if self.config_path_or_template_name.endswith(".yaml"):
|
||||
print_pip_install_help(self.config.providers)
|
||||
# Convert Provider objects to their types
|
||||
provider_types: dict[str, str | list[str]] = {}
|
||||
for api, providers in self.config.providers.items():
|
||||
types = [p.provider_type for p in providers]
|
||||
# Convert single-item lists to strings
|
||||
provider_types[api] = types[0] if len(types) == 1 else types
|
||||
build_config = BuildConfig(
|
||||
distribution_spec=DistributionSpec(
|
||||
providers=provider_types,
|
||||
),
|
||||
external_providers_dir=self.config.external_providers_dir,
|
||||
)
|
||||
print_pip_install_help(build_config)
|
||||
else:
|
||||
prefix = "!" if in_notebook() else ""
|
||||
cprint(
|
||||
|
|
|
@ -99,7 +99,7 @@ class ProviderImpl(Providers):
|
|||
try:
|
||||
health = await asyncio.wait_for(impl.health(), timeout=timeout)
|
||||
return api_name, health
|
||||
except asyncio.TimeoutError:
|
||||
except (asyncio.TimeoutError, TimeoutError):
|
||||
return (
|
||||
api_name,
|
||||
HealthResponse(
|
||||
|
|
|
@ -44,7 +44,8 @@ class RequestProviderDataContext(AbstractContextManager):
|
|||
class NeedsRequestProviderData:
|
||||
def get_request_provider_data(self) -> Any:
|
||||
spec = self.__provider_spec__
|
||||
assert spec, f"Provider spec not set on {self.__class__}"
|
||||
if not spec:
|
||||
raise ValueError(f"Provider spec not set on {self.__class__}")
|
||||
|
||||
provider_type = spec.provider_type
|
||||
validator_class = spec.provider_data_validator
|
||||
|
|
|
@ -13,7 +13,7 @@ from llama_stack.apis.datasetio import DatasetIO
|
|||
from llama_stack.apis.datasets import Datasets
|
||||
from llama_stack.apis.eval import Eval
|
||||
from llama_stack.apis.files import Files
|
||||
from llama_stack.apis.inference import Inference
|
||||
from llama_stack.apis.inference import Inference, InferenceProvider
|
||||
from llama_stack.apis.inspect import Inspect
|
||||
from llama_stack.apis.models import Models
|
||||
from llama_stack.apis.post_training import PostTraining
|
||||
|
@ -83,6 +83,13 @@ def api_protocol_map() -> dict[Api, Any]:
|
|||
}
|
||||
|
||||
|
||||
def api_protocol_map_for_compliance_check() -> dict[Api, Any]:
|
||||
return {
|
||||
**api_protocol_map(),
|
||||
Api.inference: InferenceProvider,
|
||||
}
|
||||
|
||||
|
||||
def additional_protocols_map() -> dict[Api, Any]:
|
||||
return {
|
||||
Api.inference: (ModelsProtocolPrivate, Models, Api.models),
|
||||
|
@ -302,9 +309,6 @@ async def instantiate_provider(
|
|||
inner_impls: dict[str, Any],
|
||||
dist_registry: DistributionRegistry,
|
||||
):
|
||||
protocols = api_protocol_map()
|
||||
additional_protocols = additional_protocols_map()
|
||||
|
||||
provider_spec = provider.spec
|
||||
if not hasattr(provider_spec, "module"):
|
||||
raise AttributeError(f"ProviderSpec of type {type(provider_spec)} does not have a 'module' attribute")
|
||||
|
@ -342,6 +346,8 @@ async def instantiate_provider(
|
|||
impl.__provider_spec__ = provider_spec
|
||||
impl.__provider_config__ = config
|
||||
|
||||
protocols = api_protocol_map_for_compliance_check()
|
||||
additional_protocols = additional_protocols_map()
|
||||
# TODO: check compliance for special tool groups
|
||||
# the impl should be for Api.tool_runtime, the name should be the special tool group, the protocol should be the special tool group protocol
|
||||
check_protocol_compliance(impl, protocols[provider_spec.api])
|
||||
|
|
|
@ -573,6 +573,12 @@ class InferenceRouter(Inference):
|
|||
for tool in tools:
|
||||
TypeAdapter(OpenAIChatCompletionToolParam).validate_python(tool)
|
||||
|
||||
# Some providers make tool calls even when tool_choice is "none"
|
||||
# so just clear them both out to avoid unexpected tool calls
|
||||
if tool_choice == "none" and tools is not None:
|
||||
tool_choice = None
|
||||
tools = None
|
||||
|
||||
params = dict(
|
||||
model=model_obj.identifier,
|
||||
messages=messages,
|
||||
|
@ -600,7 +606,19 @@ class InferenceRouter(Inference):
|
|||
)
|
||||
|
||||
provider = self.routing_table.get_provider_impl(model_obj.identifier)
|
||||
if stream:
|
||||
return await provider.openai_chat_completion(**params)
|
||||
else:
|
||||
return await self._nonstream_openai_chat_completion(provider, params)
|
||||
|
||||
async def _nonstream_openai_chat_completion(self, provider: Inference, params: dict) -> OpenAIChatCompletion:
|
||||
response = await provider.openai_chat_completion(**params)
|
||||
for choice in response.choices:
|
||||
# some providers return an empty list for no tool calls in non-streaming responses
|
||||
# but the OpenAI API returns None. So, set tool_calls to None if it's empty
|
||||
if choice.message and choice.message.tool_calls is not None and len(choice.message.tool_calls) == 0:
|
||||
choice.message.tool_calls = None
|
||||
return response
|
||||
|
||||
async def health(self) -> dict[str, HealthResponse]:
|
||||
health_statuses = {}
|
||||
|
@ -612,7 +630,7 @@ class InferenceRouter(Inference):
|
|||
continue
|
||||
health = await asyncio.wait_for(impl.health(), timeout=timeout)
|
||||
health_statuses[provider_id] = health
|
||||
except asyncio.TimeoutError:
|
||||
except (asyncio.TimeoutError, TimeoutError):
|
||||
health_statuses[provider_id] = HealthResponse(
|
||||
status=HealthStatus.ERROR,
|
||||
message=f"Health check timed out after {timeout} seconds",
|
||||
|
|
|
@ -93,7 +93,7 @@ class AuthenticationMiddleware:
|
|||
|
||||
# Validate token and get access attributes
|
||||
try:
|
||||
access_attributes = await self.auth_provider.validate_token(token, scope)
|
||||
validation_result = await self.auth_provider.validate_token(token, scope)
|
||||
except httpx.TimeoutException:
|
||||
logger.exception("Authentication request timed out")
|
||||
return await self._send_auth_error(send, "Authentication service timeout")
|
||||
|
@ -105,17 +105,20 @@ class AuthenticationMiddleware:
|
|||
return await self._send_auth_error(send, "Authentication service error")
|
||||
|
||||
# Store attributes in request scope for access control
|
||||
if access_attributes:
|
||||
user_attributes = access_attributes.model_dump(exclude_none=True)
|
||||
if validation_result.access_attributes:
|
||||
user_attributes = validation_result.access_attributes.model_dump(exclude_none=True)
|
||||
else:
|
||||
logger.warning("No access attributes, setting namespace to token by default")
|
||||
user_attributes = {
|
||||
"namespaces": [token],
|
||||
"roles": [token],
|
||||
}
|
||||
|
||||
# Store attributes in request scope
|
||||
scope["user_attributes"] = user_attributes
|
||||
logger.debug(f"Authentication successful: {len(scope['user_attributes'])} attributes")
|
||||
scope["principal"] = validation_result.principal
|
||||
logger.debug(
|
||||
f"Authentication successful: {validation_result.principal} with {len(scope['user_attributes'])} attributes"
|
||||
)
|
||||
|
||||
return await self.app(scope, receive, send)
|
||||
|
||||
|
|
|
@ -5,12 +5,14 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
import json
|
||||
import time
|
||||
from abc import ABC, abstractmethod
|
||||
from enum import Enum
|
||||
from urllib.parse import parse_qs
|
||||
|
||||
import httpx
|
||||
from pydantic import BaseModel, Field
|
||||
from jose import jwt
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
from llama_stack.distribution.datatypes import AccessAttributes
|
||||
from llama_stack.log import get_logger
|
||||
|
@ -18,9 +20,11 @@ from llama_stack.log import get_logger
|
|||
logger = get_logger(name=__name__, category="auth")
|
||||
|
||||
|
||||
class AuthResponse(BaseModel):
|
||||
"""The format of the authentication response from the auth endpoint."""
|
||||
|
||||
class TokenValidationResult(BaseModel):
|
||||
principal: str | None = Field(
|
||||
default=None,
|
||||
description="The principal (username or persistent identifier) of the authenticated user",
|
||||
)
|
||||
access_attributes: AccessAttributes | None = Field(
|
||||
default=None,
|
||||
description="""
|
||||
|
@ -43,6 +47,10 @@ class AuthResponse(BaseModel):
|
|||
""",
|
||||
)
|
||||
|
||||
|
||||
class AuthResponse(TokenValidationResult):
|
||||
"""The format of the authentication response from the auth endpoint."""
|
||||
|
||||
message: str | None = Field(
|
||||
default=None, description="Optional message providing additional context about the authentication result."
|
||||
)
|
||||
|
@ -69,6 +77,7 @@ class AuthProviderType(str, Enum):
|
|||
|
||||
KUBERNETES = "kubernetes"
|
||||
CUSTOM = "custom"
|
||||
OAUTH2_TOKEN = "oauth2_token"
|
||||
|
||||
|
||||
class AuthProviderConfig(BaseModel):
|
||||
|
@ -82,7 +91,7 @@ class AuthProvider(ABC):
|
|||
"""Abstract base class for authentication providers."""
|
||||
|
||||
@abstractmethod
|
||||
async def validate_token(self, token: str, scope: dict | None = None) -> AccessAttributes | None:
|
||||
async def validate_token(self, token: str, scope: dict | None = None) -> TokenValidationResult:
|
||||
"""Validate a token and return access attributes."""
|
||||
pass
|
||||
|
||||
|
@ -92,12 +101,16 @@ class AuthProvider(ABC):
|
|||
pass
|
||||
|
||||
|
||||
class KubernetesAuthProviderConfig(BaseModel):
|
||||
api_server_url: str
|
||||
ca_cert_path: str | None = None
|
||||
|
||||
|
||||
class KubernetesAuthProvider(AuthProvider):
|
||||
"""Kubernetes authentication provider that validates tokens against the Kubernetes API server."""
|
||||
|
||||
def __init__(self, config: dict[str, str]):
|
||||
self.api_server_url = config["api_server_url"]
|
||||
self.ca_cert_path = config.get("ca_cert_path")
|
||||
def __init__(self, config: KubernetesAuthProviderConfig):
|
||||
self.config = config
|
||||
self._client = None
|
||||
|
||||
async def _get_client(self):
|
||||
|
@ -110,16 +123,16 @@ class KubernetesAuthProvider(AuthProvider):
|
|||
|
||||
# Configure the client
|
||||
configuration = client.Configuration()
|
||||
configuration.host = self.api_server_url
|
||||
if self.ca_cert_path:
|
||||
configuration.ssl_ca_cert = self.ca_cert_path
|
||||
configuration.verify_ssl = bool(self.ca_cert_path)
|
||||
configuration.host = self.config.api_server_url
|
||||
if self.config.ca_cert_path:
|
||||
configuration.ssl_ca_cert = self.config.ca_cert_path
|
||||
configuration.verify_ssl = bool(self.config.ca_cert_path)
|
||||
|
||||
# Create API client
|
||||
self._client = ApiClient(configuration)
|
||||
return self._client
|
||||
|
||||
async def validate_token(self, token: str, scope: dict | None = None) -> AccessAttributes | None:
|
||||
async def validate_token(self, token: str, scope: dict | None = None) -> TokenValidationResult:
|
||||
"""Validate a Kubernetes token and return access attributes."""
|
||||
try:
|
||||
client = await self._get_client()
|
||||
|
@ -146,9 +159,12 @@ class KubernetesAuthProvider(AuthProvider):
|
|||
username = payload.get("sub", "")
|
||||
groups = payload.get("groups", [])
|
||||
|
||||
return AccessAttributes(
|
||||
return TokenValidationResult(
|
||||
principal=username,
|
||||
access_attributes=AccessAttributes(
|
||||
roles=[username], # Use username as a role
|
||||
teams=groups, # Use Kubernetes groups as teams
|
||||
),
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
|
@ -162,18 +178,125 @@ class KubernetesAuthProvider(AuthProvider):
|
|||
self._client = None
|
||||
|
||||
|
||||
def get_attributes_from_claims(claims: dict[str, str], mapping: dict[str, str]) -> AccessAttributes:
|
||||
attributes = AccessAttributes()
|
||||
for claim_key, attribute_key in mapping.items():
|
||||
if claim_key not in claims or not hasattr(attributes, attribute_key):
|
||||
continue
|
||||
claim = claims[claim_key]
|
||||
if isinstance(claim, list):
|
||||
values = claim
|
||||
else:
|
||||
values = claim.split()
|
||||
|
||||
current = getattr(attributes, attribute_key)
|
||||
if current:
|
||||
current.extend(values)
|
||||
else:
|
||||
setattr(attributes, attribute_key, values)
|
||||
return attributes
|
||||
|
||||
|
||||
class OAuth2TokenAuthProviderConfig(BaseModel):
|
||||
# The JWKS URI for collecting public keys
|
||||
jwks_uri: str
|
||||
cache_ttl: int = 3600
|
||||
audience: str = "llama-stack"
|
||||
claims_mapping: dict[str, str] = Field(
|
||||
default_factory=lambda: {
|
||||
"sub": "roles",
|
||||
"username": "roles",
|
||||
"groups": "teams",
|
||||
"team": "teams",
|
||||
"project": "projects",
|
||||
"tenant": "namespaces",
|
||||
"namespace": "namespaces",
|
||||
},
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@field_validator("claims_mapping")
|
||||
def validate_claims_mapping(cls, v):
|
||||
for key, value in v.items():
|
||||
if not value:
|
||||
raise ValueError(f"claims_mapping value cannot be empty: {key}")
|
||||
if value not in AccessAttributes.model_fields:
|
||||
raise ValueError(f"claims_mapping value is not a valid attribute: {value}")
|
||||
return v
|
||||
|
||||
|
||||
class OAuth2TokenAuthProvider(AuthProvider):
|
||||
"""
|
||||
JWT token authentication provider that validates a JWT token and extracts access attributes.
|
||||
|
||||
This should be the standard authentication provider for most use cases.
|
||||
"""
|
||||
|
||||
def __init__(self, config: OAuth2TokenAuthProviderConfig):
|
||||
self.config = config
|
||||
self._jwks_at: float = 0.0
|
||||
self._jwks: dict[str, str] = {}
|
||||
|
||||
async def validate_token(self, token: str, scope: dict | None = None) -> TokenValidationResult:
|
||||
"""Validate a token using the JWT token."""
|
||||
await self._refresh_jwks()
|
||||
|
||||
try:
|
||||
header = jwt.get_unverified_header(token)
|
||||
kid = header["kid"]
|
||||
if kid not in self._jwks:
|
||||
raise ValueError(f"Unknown key ID: {kid}")
|
||||
key_data = self._jwks[kid]
|
||||
algorithm = header.get("alg", "RS256")
|
||||
claims = jwt.decode(
|
||||
token,
|
||||
key_data,
|
||||
algorithms=[algorithm],
|
||||
audience=self.config.audience,
|
||||
options={"verify_exp": True},
|
||||
)
|
||||
except Exception as exc:
|
||||
raise ValueError(f"Invalid JWT token: {token}") from exc
|
||||
|
||||
# There are other standard claims, the most relevant of which is `scope`.
|
||||
# We should incorporate these into the access attributes.
|
||||
principal = claims["sub"]
|
||||
access_attributes = get_attributes_from_claims(claims, self.config.claims_mapping)
|
||||
return TokenValidationResult(
|
||||
principal=principal,
|
||||
access_attributes=access_attributes,
|
||||
)
|
||||
|
||||
async def close(self):
|
||||
"""Close the HTTP client."""
|
||||
|
||||
async def _refresh_jwks(self) -> None:
|
||||
if time.time() - self._jwks_at > self.config.cache_ttl:
|
||||
async with httpx.AsyncClient() as client:
|
||||
res = await client.get(self.config.jwks_uri, timeout=5)
|
||||
res.raise_for_status()
|
||||
jwks_data = res.json()["keys"]
|
||||
self._jwks = {}
|
||||
for k in jwks_data:
|
||||
kid = k["kid"]
|
||||
# Store the entire key object as it may be needed for different algorithms
|
||||
self._jwks[kid] = k
|
||||
self._jwks_at = time.time()
|
||||
|
||||
|
||||
class CustomAuthProviderConfig(BaseModel):
|
||||
endpoint: str
|
||||
|
||||
|
||||
class CustomAuthProvider(AuthProvider):
|
||||
"""Custom authentication provider that uses an external endpoint."""
|
||||
|
||||
def __init__(self, config: dict[str, str]):
|
||||
self.endpoint = config["endpoint"]
|
||||
def __init__(self, config: CustomAuthProviderConfig):
|
||||
self.config = config
|
||||
self._client = None
|
||||
|
||||
async def validate_token(self, token: str, scope: dict | None = None) -> AccessAttributes | None:
|
||||
async def validate_token(self, token: str, scope: dict | None = None) -> TokenValidationResult:
|
||||
"""Validate a token using the custom authentication endpoint."""
|
||||
if not self.endpoint:
|
||||
raise ValueError("Authentication endpoint not configured")
|
||||
|
||||
if scope is None:
|
||||
scope = {}
|
||||
|
||||
|
@ -202,7 +325,7 @@ class CustomAuthProvider(AuthProvider):
|
|||
try:
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.post(
|
||||
self.endpoint,
|
||||
self.config.endpoint,
|
||||
json=auth_request.model_dump(),
|
||||
timeout=10.0, # Add a reasonable timeout
|
||||
)
|
||||
|
@ -214,19 +337,7 @@ class CustomAuthProvider(AuthProvider):
|
|||
try:
|
||||
response_data = response.json()
|
||||
auth_response = AuthResponse(**response_data)
|
||||
|
||||
# Store attributes in request scope for access control
|
||||
if auth_response.access_attributes:
|
||||
return auth_response.access_attributes
|
||||
else:
|
||||
logger.warning("No access attributes, setting namespace to api_key by default")
|
||||
user_attributes = {
|
||||
"namespaces": [token],
|
||||
}
|
||||
|
||||
scope["user_attributes"] = user_attributes
|
||||
logger.debug(f"Authentication successful: {len(user_attributes)} attributes")
|
||||
return auth_response.access_attributes
|
||||
return auth_response
|
||||
except Exception as e:
|
||||
logger.exception("Error parsing authentication response")
|
||||
raise ValueError("Invalid authentication response format") from e
|
||||
|
@ -253,9 +364,11 @@ def create_auth_provider(config: AuthProviderConfig) -> AuthProvider:
|
|||
provider_type = config.provider_type.lower()
|
||||
|
||||
if provider_type == "kubernetes":
|
||||
return KubernetesAuthProvider(config.config)
|
||||
return KubernetesAuthProvider(KubernetesAuthProviderConfig.model_validate(config.config))
|
||||
elif provider_type == "custom":
|
||||
return CustomAuthProvider(config.config)
|
||||
return CustomAuthProvider(CustomAuthProviderConfig.model_validate(config.config))
|
||||
elif provider_type == "oauth2_token":
|
||||
return OAuth2TokenAuthProvider(OAuth2TokenAuthProviderConfig.model_validate(config.config))
|
||||
else:
|
||||
supported_providers = ", ".join([t.value for t in AuthProviderType])
|
||||
raise ValueError(f"Unsupported auth provider type: {provider_type}. Supported types are: {supported_providers}")
|
||||
|
|
|
@ -9,6 +9,7 @@ import asyncio
|
|||
import inspect
|
||||
import json
|
||||
import os
|
||||
import ssl
|
||||
import sys
|
||||
import traceback
|
||||
import warnings
|
||||
|
@ -17,6 +18,7 @@ from importlib.metadata import version as parse_version
|
|||
from pathlib import Path
|
||||
from typing import Annotated, Any
|
||||
|
||||
import rich.pretty
|
||||
import yaml
|
||||
from fastapi import Body, FastAPI, HTTPException, Request
|
||||
from fastapi import Path as FastapiPath
|
||||
|
@ -114,7 +116,7 @@ def translate_exception(exc: Exception) -> HTTPException | RequestValidationErro
|
|||
return HTTPException(status_code=400, detail=str(exc))
|
||||
elif isinstance(exc, PermissionError):
|
||||
return HTTPException(status_code=403, detail=f"Permission denied: {str(exc)}")
|
||||
elif isinstance(exc, TimeoutError):
|
||||
elif isinstance(exc, asyncio.TimeoutError | TimeoutError):
|
||||
return HTTPException(status_code=504, detail=f"Operation timed out: {str(exc)}")
|
||||
elif isinstance(exc, NotImplementedError):
|
||||
return HTTPException(status_code=501, detail=f"Not implemented: {str(exc)}")
|
||||
|
@ -139,7 +141,7 @@ async def shutdown(app):
|
|||
await asyncio.wait_for(impl.shutdown(), timeout=5)
|
||||
else:
|
||||
logger.warning("No shutdown method for %s", impl_name)
|
||||
except asyncio.TimeoutError:
|
||||
except (asyncio.TimeoutError, TimeoutError):
|
||||
logger.exception("Shutdown timeout for %s ", impl_name, exc_info=True)
|
||||
except (Exception, asyncio.CancelledError) as e:
|
||||
logger.exception("Failed to shutdown %s: %s", impl_name, {e})
|
||||
|
@ -186,11 +188,30 @@ async def sse_generator(event_gen_coroutine):
|
|||
)
|
||||
|
||||
|
||||
async def log_request_pre_validation(request: Request):
|
||||
if request.method in ("POST", "PUT", "PATCH"):
|
||||
try:
|
||||
body_bytes = await request.body()
|
||||
if body_bytes:
|
||||
try:
|
||||
parsed_body = json.loads(body_bytes.decode())
|
||||
log_output = rich.pretty.pretty_repr(parsed_body)
|
||||
except (json.JSONDecodeError, UnicodeDecodeError):
|
||||
log_output = repr(body_bytes)
|
||||
logger.debug(f"Incoming raw request body for {request.method} {request.url.path}:\n{log_output}")
|
||||
else:
|
||||
logger.debug(f"Incoming {request.method} {request.url.path} request with empty body.")
|
||||
except Exception as e:
|
||||
logger.warning(f"Could not read or log request body for {request.method} {request.url.path}: {e}")
|
||||
|
||||
|
||||
def create_dynamic_typed_route(func: Any, method: str, route: str):
|
||||
async def endpoint(request: Request, **kwargs):
|
||||
# Get auth attributes from the request scope
|
||||
user_attributes = request.scope.get("user_attributes", {})
|
||||
|
||||
await log_request_pre_validation(request)
|
||||
|
||||
# Use context manager with both provider data and auth attributes
|
||||
with request_provider_data_context(request.headers, user_attributes):
|
||||
is_streaming = is_streaming_request(func.__name__, request, **kwargs)
|
||||
|
@ -337,22 +358,11 @@ def main(args: argparse.Namespace | None = None):
|
|||
default=int(os.getenv("LLAMA_STACK_PORT", 8321)),
|
||||
help="Port to listen on",
|
||||
)
|
||||
parser.add_argument("--disable-ipv6", action="store_true", help="Whether to disable IPv6 support")
|
||||
parser.add_argument(
|
||||
"--env",
|
||||
action="append",
|
||||
help="Environment variables in KEY=value format. Can be specified multiple times.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--tls-keyfile",
|
||||
help="Path to TLS key file for HTTPS",
|
||||
required="--tls-certfile" in sys.argv,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--tls-certfile",
|
||||
help="Path to TLS certificate file for HTTPS",
|
||||
required="--tls-keyfile" in sys.argv,
|
||||
)
|
||||
|
||||
# Determine whether the server args are being passed by the "run" command, if this is the case
|
||||
# the args will be passed as a Namespace object to the main function, otherwise they will be
|
||||
|
@ -361,9 +371,9 @@ def main(args: argparse.Namespace | None = None):
|
|||
args = parser.parse_args()
|
||||
|
||||
# Check for deprecated argument usage
|
||||
if "--yaml-config" in sys.argv:
|
||||
if "--config" in sys.argv:
|
||||
warnings.warn(
|
||||
"The '--yaml-config' argument is deprecated and will be removed in a future version. Use '--config' instead.",
|
||||
"The '--config' argument is deprecated and will be removed in a future version. Use '--config' instead.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
|
@ -381,7 +391,7 @@ def main(args: argparse.Namespace | None = None):
|
|||
raise ValueError(f"Template {args.template} does not exist")
|
||||
log_line = f"Using template {args.template} config file: {config_file}"
|
||||
else:
|
||||
raise ValueError("Either --yaml-config or --template must be provided")
|
||||
raise ValueError("Either --config or --template must be provided")
|
||||
|
||||
logger_config = None
|
||||
with open(config_file) as fp:
|
||||
|
@ -486,10 +496,6 @@ def main(args: argparse.Namespace | None = None):
|
|||
port = args.port or config.server.port
|
||||
|
||||
ssl_config = None
|
||||
if args.tls_keyfile:
|
||||
keyfile = args.tls_keyfile
|
||||
certfile = args.tls_certfile
|
||||
else:
|
||||
keyfile = config.server.tls_keyfile
|
||||
certfile = config.server.tls_certfile
|
||||
|
||||
|
@ -498,9 +504,16 @@ def main(args: argparse.Namespace | None = None):
|
|||
"ssl_keyfile": keyfile,
|
||||
"ssl_certfile": certfile,
|
||||
}
|
||||
if config.server.tls_cafile:
|
||||
ssl_config["ssl_ca_certs"] = config.server.tls_cafile
|
||||
ssl_config["ssl_cert_reqs"] = ssl.CERT_REQUIRED
|
||||
logger.info(
|
||||
f"HTTPS enabled with certificates:\n Key: {keyfile}\n Cert: {certfile}\n CA: {config.server.tls_cafile}"
|
||||
)
|
||||
else:
|
||||
logger.info(f"HTTPS enabled with certificates:\n Key: {keyfile}\n Cert: {certfile}")
|
||||
|
||||
listen_host = ["::", "0.0.0.0"] if not args.disable_ipv6 else "0.0.0.0"
|
||||
listen_host = config.server.host or ["::", "0.0.0.0"]
|
||||
logger.info(f"Listening on {listen_host}:{port}")
|
||||
|
||||
uvicorn_config = {
|
||||
|
|
|
@ -29,7 +29,7 @@ error_handler() {
|
|||
trap 'error_handler ${LINENO}' ERR
|
||||
|
||||
if [ $# -lt 3 ]; then
|
||||
echo "Usage: $0 <env_type> <env_path_or_name> <yaml_config> <port> <script_args...>"
|
||||
echo "Usage: $0 <env_type> <env_path_or_name> <port> [--config <yaml_config>] [--env KEY=VALUE]..."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
|
@ -40,23 +40,30 @@ env_path_or_name="$1"
|
|||
container_image="localhost/$env_path_or_name"
|
||||
shift
|
||||
|
||||
yaml_config="$1"
|
||||
shift
|
||||
|
||||
port="$1"
|
||||
shift
|
||||
|
||||
SCRIPT_DIR=$(dirname "$(readlink -f "$0")")
|
||||
source "$SCRIPT_DIR/common.sh"
|
||||
|
||||
# Initialize env_vars as an string
|
||||
# Initialize variables
|
||||
yaml_config=""
|
||||
env_vars=""
|
||||
other_args=""
|
||||
# Process environment variables from --env arguments
|
||||
|
||||
# Process remaining arguments
|
||||
while [[ $# -gt 0 ]]; do
|
||||
case "$1" in
|
||||
--config)
|
||||
if [[ -n "$2" ]]; then
|
||||
yaml_config="$2"
|
||||
shift 2
|
||||
else
|
||||
echo -e "${RED}Error: $1 requires a CONFIG argument${NC}" >&2
|
||||
exit 1
|
||||
fi
|
||||
;;
|
||||
--env)
|
||||
|
||||
if [[ -n "$2" ]]; then
|
||||
env_vars="$env_vars --env $2"
|
||||
shift 2
|
||||
|
@ -71,6 +78,13 @@ while [[ $# -gt 0 ]]; do
|
|||
;;
|
||||
esac
|
||||
done
|
||||
|
||||
# Check if yaml_config is required based on env_type
|
||||
if [[ "$env_type" == "venv" || "$env_type" == "conda" ]] && [ -z "$yaml_config" ]; then
|
||||
echo -e "${RED}Error: --config is required for venv and conda environments${NC}" >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
PYTHON_BINARY="python"
|
||||
case "$env_type" in
|
||||
"venv")
|
||||
|
@ -106,8 +120,14 @@ esac
|
|||
if [[ "$env_type" == "venv" || "$env_type" == "conda" ]]; then
|
||||
set -x
|
||||
|
||||
if [ -n "$yaml_config" ]; then
|
||||
yaml_config_arg="--config $yaml_config"
|
||||
else
|
||||
yaml_config_arg=""
|
||||
fi
|
||||
|
||||
$PYTHON_BINARY -m llama_stack.distribution.server.server \
|
||||
--yaml-config "$yaml_config" \
|
||||
$yaml_config_arg \
|
||||
--port "$port" \
|
||||
$env_vars \
|
||||
$other_args
|
||||
|
@ -149,15 +169,26 @@ elif [[ "$env_type" == "container" ]]; then
|
|||
version_tag=$(curl -s $URL | jq -r '.info.version')
|
||||
fi
|
||||
|
||||
$CONTAINER_BINARY run $CONTAINER_OPTS -it \
|
||||
# Build the command with optional yaml config
|
||||
cmd="$CONTAINER_BINARY run $CONTAINER_OPTS -it \
|
||||
-p $port:$port \
|
||||
$env_vars \
|
||||
-v "$yaml_config:/app/config.yaml" \
|
||||
$mounts \
|
||||
--env LLAMA_STACK_PORT=$port \
|
||||
--entrypoint python \
|
||||
$container_image:$version_tag \
|
||||
-m llama_stack.distribution.server.server \
|
||||
--yaml-config /app/config.yaml \
|
||||
$other_args
|
||||
-m llama_stack.distribution.server.server"
|
||||
|
||||
# Add yaml config if provided, otherwise use default
|
||||
if [ -n "$yaml_config" ]; then
|
||||
cmd="$cmd -v $yaml_config:/app/run.yaml --config /app/run.yaml"
|
||||
else
|
||||
cmd="$cmd --config /app/run.yaml"
|
||||
fi
|
||||
|
||||
# Add any other args
|
||||
cmd="$cmd $other_args"
|
||||
|
||||
# Execute the command
|
||||
eval $cmd
|
||||
fi
|
||||
|
|
|
@ -73,7 +73,7 @@ class DiskDistributionRegistry(DistributionRegistry):
|
|||
|
||||
async def get_all(self) -> list[RoutableObjectWithProvider]:
|
||||
start_key, end_key = _get_registry_key_range()
|
||||
values = await self.kvstore.range(start_key, end_key)
|
||||
values = await self.kvstore.values_in_range(start_key, end_key)
|
||||
return _parse_registry_values(values)
|
||||
|
||||
async def get(self, type: str, identifier: str) -> RoutableObjectWithProvider | None:
|
||||
|
@ -134,7 +134,7 @@ class CachedDiskDistributionRegistry(DiskDistributionRegistry):
|
|||
return
|
||||
|
||||
start_key, end_key = _get_registry_key_range()
|
||||
values = await self.kvstore.range(start_key, end_key)
|
||||
values = await self.kvstore.values_in_range(start_key, end_key)
|
||||
objects = _parse_registry_values(values)
|
||||
|
||||
async with self._locked_cache() as cache:
|
||||
|
|
|
@ -124,7 +124,7 @@ if prompt := st.chat_input("Example: What is Llama Stack?"):
|
|||
message_placeholder.markdown(full_response + "▌")
|
||||
message_placeholder.markdown(full_response)
|
||||
else:
|
||||
full_response = response
|
||||
message_placeholder.markdown(full_response.completion_message.content)
|
||||
full_response = response.completion_message.content
|
||||
message_placeholder.markdown(full_response)
|
||||
|
||||
st.session_state.messages.append({"role": "assistant", "content": full_response})
|
||||
|
|
|
@ -14,3 +14,5 @@ DISTRIBS_BASE_DIR = LLAMA_STACK_CONFIG_DIR / "distributions"
|
|||
DEFAULT_CHECKPOINT_DIR = LLAMA_STACK_CONFIG_DIR / "checkpoints"
|
||||
|
||||
RUNTIME_BASE_DIR = LLAMA_STACK_CONFIG_DIR / "runtime"
|
||||
|
||||
EXTERNAL_PROVIDERS_DIR = LLAMA_STACK_CONFIG_DIR / "providers.d"
|
||||
|
|
|
@ -22,8 +22,10 @@ from llama_stack.distribution.utils.image_types import LlamaStackImageType
|
|||
|
||||
def formulate_run_args(image_type, image_name, config, template_name) -> list:
|
||||
env_name = ""
|
||||
if image_type == LlamaStackImageType.CONTAINER.value or config.container_image:
|
||||
env_name = f"distribution-{template_name}" if template_name else config.container_image
|
||||
if image_type == LlamaStackImageType.CONTAINER.value:
|
||||
env_name = (
|
||||
f"distribution-{template_name}" if template_name else (config.container_image if config else image_name)
|
||||
)
|
||||
elif image_type == LlamaStackImageType.CONDA.value:
|
||||
current_conda_env = os.environ.get("CONDA_DEFAULT_ENV")
|
||||
env_name = image_name or current_conda_env
|
||||
|
|
|
@ -245,7 +245,7 @@ class PythonListCustomToolGenerator(PromptTemplateGeneratorBase): # noqa: N801
|
|||
{"function_description": self._gen_function_description(custom_tools)},
|
||||
)
|
||||
|
||||
def _gen_function_description(self, custom_tools: list[ToolDefinition]) -> PromptTemplate:
|
||||
def _gen_function_description(self, custom_tools: list[ToolDefinition]) -> str:
|
||||
template_str = textwrap.dedent(
|
||||
"""
|
||||
Here is a list of functions in JSON format that you can invoke.
|
||||
|
@ -286,10 +286,12 @@ class PythonListCustomToolGenerator(PromptTemplateGeneratorBase): # noqa: N801
|
|||
|
||||
"""
|
||||
)
|
||||
return PromptTemplate(
|
||||
template = PromptTemplate(
|
||||
template_str.strip("\n"),
|
||||
{"tools": [t.model_dump() for t in custom_tools]},
|
||||
).render()
|
||||
)
|
||||
rendered: str = template.render()
|
||||
return rendered
|
||||
|
||||
def data_examples(self) -> list[list[ToolDefinition]]:
|
||||
return [
|
||||
|
|
|
@ -173,9 +173,7 @@ INCORRECT: [get_events(location="Singapore")] <- If function not in list
|
|||
- Don't repeat tool response verbatim
|
||||
- Don't add supplementary information
|
||||
|
||||
|
||||
Here is a list of functions in JSON format that you can invoke.
|
||||
|
||||
Here is a list of functions in JSON format that you can invoke:
|
||||
[
|
||||
{
|
||||
"name": "get_weather",
|
||||
|
@ -196,10 +194,7 @@ Here is a list of functions in JSON format that you can invoke.
|
|||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
|
||||
You can answer general questions or invoke tools when necessary.
|
||||
In addition to tool calls, you should also augment your responses by using the tool outputs.<|eot|><|header_start|>user<|header_end|>
|
||||
]<|eot|><|header_start|>user<|header_end|>
|
||||
|
||||
What is the weather in SF and Seattle?<|eot|><|header_start|>assistant<|header_end|>
|
||||
|
||||
|
|
|
@ -61,7 +61,6 @@ class PythonListCustomToolGenerator(PromptTemplateGeneratorBase): # noqa: N801
|
|||
- Don't repeat tool response verbatim
|
||||
- Don't add supplementary information
|
||||
|
||||
|
||||
{{ function_description }}
|
||||
""".strip("\n")
|
||||
)
|
||||
|
@ -76,8 +75,7 @@ class PythonListCustomToolGenerator(PromptTemplateGeneratorBase): # noqa: N801
|
|||
def _gen_function_description(self, custom_tools: list[ToolDefinition]) -> PromptTemplate:
|
||||
template_str = textwrap.dedent(
|
||||
"""
|
||||
Here is a list of functions in JSON format that you can invoke.
|
||||
|
||||
Here is a list of functions in JSON format that you can invoke:
|
||||
[
|
||||
{% for t in tools -%}
|
||||
{# manually setting up JSON because jinja sorts keys in unexpected ways -#}
|
||||
|
@ -108,10 +106,6 @@ class PythonListCustomToolGenerator(PromptTemplateGeneratorBase): # noqa: N801
|
|||
{% endif -%}
|
||||
{%- endfor %}
|
||||
]
|
||||
|
||||
You can answer general questions or invoke tools when necessary.
|
||||
In addition to tool calls, you should also augment your responses by using the tool outputs.
|
||||
|
||||
"""
|
||||
)
|
||||
return PromptTemplate(
|
||||
|
|
|
@ -948,6 +948,8 @@ def llama_meta_net_info(model: Model) -> LlamaDownloadInfo:
|
|||
elif model.core_model_id == CoreModelId.llama_guard_2_8b:
|
||||
folder = "llama-guard-2"
|
||||
else:
|
||||
if model.huggingface_repo is None:
|
||||
raise ValueError(f"Model {model.core_model_id} has no huggingface_repo set")
|
||||
folder = model.huggingface_repo.split("/")[-1]
|
||||
if "Llama-2" in folder:
|
||||
folder = folder.lower()
|
||||
|
@ -1024,3 +1026,4 @@ def llama_meta_pth_size(model: Model) -> int:
|
|||
return 54121549657
|
||||
else:
|
||||
return 100426653046
|
||||
return 0
|
||||
|
|
|
@ -95,6 +95,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
tool_groups_api: ToolGroups,
|
||||
vector_io_api: VectorIO,
|
||||
persistence_store: KVStore,
|
||||
created_at: str,
|
||||
):
|
||||
self.agent_id = agent_id
|
||||
self.agent_config = agent_config
|
||||
|
@ -104,6 +105,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
self.storage = AgentPersistence(agent_id, persistence_store)
|
||||
self.tool_runtime_api = tool_runtime_api
|
||||
self.tool_groups_api = tool_groups_api
|
||||
self.created_at = created_at
|
||||
|
||||
ShieldRunnerMixin.__init__(
|
||||
self,
|
||||
|
|
|
@ -4,10 +4,10 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import json
|
||||
import logging
|
||||
import uuid
|
||||
from collections.abc import AsyncGenerator
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from llama_stack.apis.agents import (
|
||||
Agent,
|
||||
|
@ -20,14 +20,13 @@ from llama_stack.apis.agents import (
|
|||
AgentTurnCreateRequest,
|
||||
AgentTurnResumeRequest,
|
||||
Document,
|
||||
ListAgentSessionsResponse,
|
||||
ListAgentsResponse,
|
||||
OpenAIResponseInputMessage,
|
||||
OpenAIResponseInput,
|
||||
OpenAIResponseInputTool,
|
||||
OpenAIResponseObject,
|
||||
Session,
|
||||
Turn,
|
||||
)
|
||||
from llama_stack.apis.common.responses import PaginatedResponse
|
||||
from llama_stack.apis.inference import (
|
||||
Inference,
|
||||
ToolConfig,
|
||||
|
@ -39,13 +38,14 @@ from llama_stack.apis.safety import Safety
|
|||
from llama_stack.apis.tools import ToolGroups, ToolRuntime
|
||||
from llama_stack.apis.vector_io import VectorIO
|
||||
from llama_stack.providers.utils.kvstore import InmemoryKVStoreImpl, kvstore_impl
|
||||
from llama_stack.providers.utils.pagination import paginate_records
|
||||
|
||||
from .agent_instance import ChatAgent
|
||||
from .config import MetaReferenceAgentsImplConfig
|
||||
from .openai_responses import OpenAIResponsesImpl
|
||||
from .persistence import AgentInfo
|
||||
|
||||
logger = logging.getLogger()
|
||||
logger.setLevel(logging.INFO)
|
||||
|
||||
|
||||
class MetaReferenceAgentsImpl(Agents):
|
||||
|
@ -82,43 +82,47 @@ class MetaReferenceAgentsImpl(Agents):
|
|||
agent_config: AgentConfig,
|
||||
) -> AgentCreateResponse:
|
||||
agent_id = str(uuid.uuid4())
|
||||
created_at = datetime.now(timezone.utc)
|
||||
|
||||
agent_info = AgentInfo(
|
||||
**agent_config.model_dump(),
|
||||
created_at=created_at,
|
||||
)
|
||||
|
||||
# Store the agent info
|
||||
await self.persistence_store.set(
|
||||
key=f"agent:{agent_id}",
|
||||
value=agent_config.model_dump_json(),
|
||||
value=agent_info.model_dump_json(),
|
||||
)
|
||||
|
||||
return AgentCreateResponse(
|
||||
agent_id=agent_id,
|
||||
)
|
||||
|
||||
async def _get_agent_impl(self, agent_id: str) -> ChatAgent:
|
||||
agent_config = await self.persistence_store.get(
|
||||
agent_info_json = await self.persistence_store.get(
|
||||
key=f"agent:{agent_id}",
|
||||
)
|
||||
if not agent_config:
|
||||
raise ValueError(f"Could not find agent config for {agent_id}")
|
||||
if not agent_info_json:
|
||||
raise ValueError(f"Could not find agent info for {agent_id}")
|
||||
|
||||
try:
|
||||
agent_config = json.loads(agent_config)
|
||||
except json.JSONDecodeError as e:
|
||||
raise ValueError(f"Could not JSON decode agent config for {agent_id}") from e
|
||||
|
||||
try:
|
||||
agent_config = AgentConfig(**agent_config)
|
||||
agent_info = AgentInfo.model_validate_json(agent_info_json)
|
||||
except Exception as e:
|
||||
raise ValueError(f"Could not validate(?) agent config for {agent_id}") from e
|
||||
raise ValueError(f"Could not validate agent info for {agent_id}") from e
|
||||
|
||||
return ChatAgent(
|
||||
agent_id=agent_id,
|
||||
agent_config=agent_config,
|
||||
agent_config=agent_info,
|
||||
inference_api=self.inference_api,
|
||||
safety_api=self.safety_api,
|
||||
vector_io_api=self.vector_io_api,
|
||||
tool_runtime_api=self.tool_runtime_api,
|
||||
tool_groups_api=self.tool_groups_api,
|
||||
persistence_store=(
|
||||
self.persistence_store if agent_config.enable_session_persistence else self.in_memory_store
|
||||
self.persistence_store if agent_info.enable_session_persistence else self.in_memory_store
|
||||
),
|
||||
created_at=agent_info.created_at,
|
||||
)
|
||||
|
||||
async def create_agent_session(
|
||||
|
@ -212,6 +216,7 @@ class MetaReferenceAgentsImpl(Agents):
|
|||
turn_ids: list[str] | None = None,
|
||||
) -> Session:
|
||||
agent = await self._get_agent_impl(agent_id)
|
||||
|
||||
session_info = await agent.storage.get_session_info(session_id)
|
||||
if session_info is None:
|
||||
raise ValueError(f"Session {session_id} not found")
|
||||
|
@ -226,24 +231,75 @@ class MetaReferenceAgentsImpl(Agents):
|
|||
)
|
||||
|
||||
async def delete_agents_session(self, agent_id: str, session_id: str) -> None:
|
||||
await self.persistence_store.delete(f"session:{agent_id}:{session_id}")
|
||||
agent = await self._get_agent_impl(agent_id)
|
||||
session_info = await agent.storage.get_session_info(session_id)
|
||||
if session_info is None:
|
||||
raise ValueError(f"Session {session_id} not found")
|
||||
|
||||
# Delete turns first, then the session
|
||||
await agent.storage.delete_session_turns(session_id)
|
||||
await agent.storage.delete_session(session_id)
|
||||
|
||||
async def delete_agent(self, agent_id: str) -> None:
|
||||
# First get all sessions for this agent
|
||||
agent = await self._get_agent_impl(agent_id)
|
||||
sessions = await agent.storage.list_sessions()
|
||||
|
||||
# Delete all sessions
|
||||
for session in sessions:
|
||||
await self.delete_agents_session(agent_id, session.session_id)
|
||||
|
||||
# Finally delete the agent itself
|
||||
await self.persistence_store.delete(f"agent:{agent_id}")
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
pass
|
||||
async def list_agents(self, start_index: int | None = None, limit: int | None = None) -> PaginatedResponse:
|
||||
agent_keys = await self.persistence_store.keys_in_range("agent:", "agent:\xff")
|
||||
agent_list: list[Agent] = []
|
||||
for agent_key in agent_keys:
|
||||
agent_id = agent_key.split(":")[1]
|
||||
|
||||
async def list_agents(self) -> ListAgentsResponse:
|
||||
pass
|
||||
# Get the agent info using the key
|
||||
agent_info_json = await self.persistence_store.get(agent_key)
|
||||
if not agent_info_json:
|
||||
logger.error(f"Could not find agent info for key {agent_key}")
|
||||
continue
|
||||
|
||||
try:
|
||||
agent_info = AgentInfo.model_validate_json(agent_info_json)
|
||||
agent_list.append(
|
||||
Agent(
|
||||
agent_id=agent_id,
|
||||
agent_config=agent_info,
|
||||
created_at=agent_info.created_at,
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error parsing agent info for {agent_id}: {e}")
|
||||
continue
|
||||
|
||||
# Convert Agent objects to dictionaries
|
||||
agent_dicts = [agent.model_dump() for agent in agent_list]
|
||||
return paginate_records(agent_dicts, start_index, limit)
|
||||
|
||||
async def get_agent(self, agent_id: str) -> Agent:
|
||||
pass
|
||||
chat_agent = await self._get_agent_impl(agent_id)
|
||||
agent = Agent(
|
||||
agent_id=agent_id,
|
||||
agent_config=chat_agent.agent_config,
|
||||
created_at=chat_agent.created_at,
|
||||
)
|
||||
return agent
|
||||
|
||||
async def list_agent_sessions(
|
||||
self,
|
||||
agent_id: str,
|
||||
) -> ListAgentSessionsResponse:
|
||||
self, agent_id: str, start_index: int | None = None, limit: int | None = None
|
||||
) -> PaginatedResponse:
|
||||
agent = await self._get_agent_impl(agent_id)
|
||||
sessions = await agent.storage.list_sessions()
|
||||
# Convert Session objects to dictionaries
|
||||
session_dicts = [session.model_dump() for session in sessions]
|
||||
return paginate_records(session_dicts, start_index, limit)
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
pass
|
||||
|
||||
# OpenAI responses
|
||||
|
@ -255,7 +311,7 @@ class MetaReferenceAgentsImpl(Agents):
|
|||
|
||||
async def create_openai_response(
|
||||
self,
|
||||
input: str | list[OpenAIResponseInputMessage],
|
||||
input: str | list[OpenAIResponseInput],
|
||||
model: str,
|
||||
previous_response_id: str | None = None,
|
||||
store: bool | None = True,
|
||||
|
|
|
@ -7,22 +7,29 @@
|
|||
import json
|
||||
import uuid
|
||||
from collections.abc import AsyncIterator
|
||||
from typing import cast
|
||||
from typing import Any, cast
|
||||
|
||||
from openai.types.chat import ChatCompletionToolParam
|
||||
from pydantic import BaseModel
|
||||
|
||||
from llama_stack.apis.agents.openai_responses import (
|
||||
OpenAIResponseInputMessage,
|
||||
OpenAIResponseInput,
|
||||
OpenAIResponseInputFunctionToolCallOutput,
|
||||
OpenAIResponseInputItemList,
|
||||
OpenAIResponseInputMessageContent,
|
||||
OpenAIResponseInputMessageContentImage,
|
||||
OpenAIResponseInputMessageContentText,
|
||||
OpenAIResponseInputTool,
|
||||
OpenAIResponseInputToolFunction,
|
||||
OpenAIResponseMessage,
|
||||
OpenAIResponseObject,
|
||||
OpenAIResponseObjectStream,
|
||||
OpenAIResponseObjectStreamResponseCompleted,
|
||||
OpenAIResponseObjectStreamResponseCreated,
|
||||
OpenAIResponseOutput,
|
||||
OpenAIResponseOutputMessage,
|
||||
OpenAIResponseOutputMessageContent,
|
||||
OpenAIResponseOutputMessageContentOutputText,
|
||||
OpenAIResponseOutputMessageFunctionToolCall,
|
||||
OpenAIResponseOutputMessageWebSearchToolCall,
|
||||
)
|
||||
from llama_stack.apis.inference.inference import (
|
||||
|
@ -32,10 +39,13 @@ from llama_stack.apis.inference.inference import (
|
|||
OpenAIChatCompletionContentPartImageParam,
|
||||
OpenAIChatCompletionContentPartParam,
|
||||
OpenAIChatCompletionContentPartTextParam,
|
||||
OpenAIChatCompletionToolCall,
|
||||
OpenAIChatCompletionToolCallFunction,
|
||||
OpenAIChoice,
|
||||
OpenAIDeveloperMessageParam,
|
||||
OpenAIImageURL,
|
||||
OpenAIMessageParam,
|
||||
OpenAISystemMessageParam,
|
||||
OpenAIToolMessageParam,
|
||||
OpenAIUserMessageParam,
|
||||
)
|
||||
|
@ -50,31 +60,110 @@ logger = get_logger(name=__name__, category="openai_responses")
|
|||
OPENAI_RESPONSES_PREFIX = "openai_responses:"
|
||||
|
||||
|
||||
async def _previous_response_to_messages(previous_response: OpenAIResponseObject) -> list[OpenAIMessageParam]:
|
||||
async def _convert_response_content_to_chat_content(
|
||||
content: str | list[OpenAIResponseInputMessageContent] | list[OpenAIResponseOutputMessageContent],
|
||||
) -> str | list[OpenAIChatCompletionContentPartParam]:
|
||||
"""
|
||||
Convert the content parts from an OpenAI Response API request into OpenAI Chat Completion content parts.
|
||||
|
||||
The content schemas of each API look similar, but are not exactly the same.
|
||||
"""
|
||||
if isinstance(content, str):
|
||||
return content
|
||||
|
||||
converted_parts = []
|
||||
for content_part in content:
|
||||
if isinstance(content_part, OpenAIResponseInputMessageContentText):
|
||||
converted_parts.append(OpenAIChatCompletionContentPartTextParam(text=content_part.text))
|
||||
elif isinstance(content_part, OpenAIResponseOutputMessageContentOutputText):
|
||||
converted_parts.append(OpenAIChatCompletionContentPartTextParam(text=content_part.text))
|
||||
elif isinstance(content_part, OpenAIResponseInputMessageContentImage):
|
||||
if content_part.image_url:
|
||||
image_url = OpenAIImageURL(url=content_part.image_url, detail=content_part.detail)
|
||||
converted_parts.append(OpenAIChatCompletionContentPartImageParam(image_url=image_url))
|
||||
elif isinstance(content_part, str):
|
||||
converted_parts.append(OpenAIChatCompletionContentPartTextParam(text=content_part))
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Llama Stack OpenAI Responses does not yet support content type '{type(content_part)}' in this context"
|
||||
)
|
||||
return converted_parts
|
||||
|
||||
|
||||
async def _convert_response_input_to_chat_messages(
|
||||
input: str | list[OpenAIResponseInput],
|
||||
) -> list[OpenAIMessageParam]:
|
||||
"""
|
||||
Convert the input from an OpenAI Response API request into OpenAI Chat Completion messages.
|
||||
"""
|
||||
messages: list[OpenAIMessageParam] = []
|
||||
for output_message in previous_response.output:
|
||||
if isinstance(output_message, OpenAIResponseOutputMessage):
|
||||
messages.append(OpenAIAssistantMessageParam(content=output_message.content[0].text))
|
||||
if isinstance(input, list):
|
||||
for input_item in input:
|
||||
if isinstance(input_item, OpenAIResponseInputFunctionToolCallOutput):
|
||||
messages.append(
|
||||
OpenAIToolMessageParam(
|
||||
content=input_item.output,
|
||||
tool_call_id=input_item.call_id,
|
||||
)
|
||||
)
|
||||
elif isinstance(input_item, OpenAIResponseOutputMessageFunctionToolCall):
|
||||
tool_call = OpenAIChatCompletionToolCall(
|
||||
index=0,
|
||||
id=input_item.call_id,
|
||||
function=OpenAIChatCompletionToolCallFunction(
|
||||
name=input_item.name,
|
||||
arguments=input_item.arguments,
|
||||
),
|
||||
)
|
||||
messages.append(OpenAIAssistantMessageParam(tool_calls=[tool_call]))
|
||||
else:
|
||||
content = await _convert_response_content_to_chat_content(input_item.content)
|
||||
message_type = await _get_message_type_by_role(input_item.role)
|
||||
if message_type is None:
|
||||
raise ValueError(
|
||||
f"Llama Stack OpenAI Responses does not yet support message role '{input_item.role}' in this context"
|
||||
)
|
||||
messages.append(message_type(content=content))
|
||||
else:
|
||||
messages.append(OpenAIUserMessageParam(content=input))
|
||||
return messages
|
||||
|
||||
|
||||
async def _openai_choices_to_output_messages(choices: list[OpenAIChoice]) -> list[OpenAIResponseOutputMessage]:
|
||||
output_messages = []
|
||||
for choice in choices:
|
||||
async def _convert_chat_choice_to_response_message(choice: OpenAIChoice) -> OpenAIResponseMessage:
|
||||
"""
|
||||
Convert an OpenAI Chat Completion choice into an OpenAI Response output message.
|
||||
"""
|
||||
output_content = ""
|
||||
if isinstance(choice.message.content, str):
|
||||
output_content = choice.message.content
|
||||
elif isinstance(choice.message.content, OpenAIChatCompletionContentPartTextParam):
|
||||
output_content = choice.message.content.text
|
||||
# TODO: handle image content
|
||||
output_messages.append(
|
||||
OpenAIResponseOutputMessage(
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Llama Stack OpenAI Responses does not yet support output content type: {type(choice.message.content)}"
|
||||
)
|
||||
|
||||
return OpenAIResponseMessage(
|
||||
id=f"msg_{uuid.uuid4()}",
|
||||
content=[OpenAIResponseOutputMessageContentOutputText(text=output_content)],
|
||||
status="completed",
|
||||
role="assistant",
|
||||
)
|
||||
)
|
||||
return output_messages
|
||||
|
||||
|
||||
async def _get_message_type_by_role(role: str):
|
||||
role_to_type = {
|
||||
"user": OpenAIUserMessageParam,
|
||||
"system": OpenAISystemMessageParam,
|
||||
"assistant": OpenAIAssistantMessageParam,
|
||||
"developer": OpenAIDeveloperMessageParam,
|
||||
}
|
||||
return role_to_type.get(role)
|
||||
|
||||
|
||||
class OpenAIResponsePreviousResponseWithInputItems(BaseModel):
|
||||
input_items: OpenAIResponseInputItemList
|
||||
response: OpenAIResponseObject
|
||||
|
||||
|
||||
class OpenAIResponsesImpl:
|
||||
|
@ -90,19 +179,45 @@ class OpenAIResponsesImpl:
|
|||
self.tool_groups_api = tool_groups_api
|
||||
self.tool_runtime_api = tool_runtime_api
|
||||
|
||||
async def get_openai_response(
|
||||
self,
|
||||
id: str,
|
||||
) -> OpenAIResponseObject:
|
||||
async def _get_previous_response_with_input(self, id: str) -> OpenAIResponsePreviousResponseWithInputItems:
|
||||
key = f"{OPENAI_RESPONSES_PREFIX}{id}"
|
||||
response_json = await self.persistence_store.get(key=key)
|
||||
if response_json is None:
|
||||
raise ValueError(f"OpenAI response with id '{id}' not found")
|
||||
return OpenAIResponseObject.model_validate_json(response_json)
|
||||
return OpenAIResponsePreviousResponseWithInputItems.model_validate_json(response_json)
|
||||
|
||||
async def _prepend_previous_response(
|
||||
self, input: str | list[OpenAIResponseInput], previous_response_id: str | None = None
|
||||
):
|
||||
if previous_response_id:
|
||||
previous_response_with_input = await self._get_previous_response_with_input(previous_response_id)
|
||||
|
||||
# previous response input items
|
||||
new_input_items = previous_response_with_input.input_items.data
|
||||
|
||||
# previous response output items
|
||||
new_input_items.extend(previous_response_with_input.response.output)
|
||||
|
||||
# new input items from the current request
|
||||
if isinstance(input, str):
|
||||
new_input_items.append(OpenAIResponseMessage(content=input, role="user"))
|
||||
else:
|
||||
new_input_items.extend(input)
|
||||
|
||||
input = new_input_items
|
||||
|
||||
return input
|
||||
|
||||
async def get_openai_response(
|
||||
self,
|
||||
id: str,
|
||||
) -> OpenAIResponseObject:
|
||||
response_with_input = await self._get_previous_response_with_input(id)
|
||||
return response_with_input.response
|
||||
|
||||
async def create_openai_response(
|
||||
self,
|
||||
input: str | list[OpenAIResponseInputMessage],
|
||||
input: str | list[OpenAIResponseInput],
|
||||
model: str,
|
||||
previous_response_id: str | None = None,
|
||||
store: bool | None = True,
|
||||
|
@ -112,31 +227,8 @@ class OpenAIResponsesImpl:
|
|||
):
|
||||
stream = False if stream is None else stream
|
||||
|
||||
messages: list[OpenAIMessageParam] = []
|
||||
if previous_response_id:
|
||||
previous_response = await self.get_openai_response(previous_response_id)
|
||||
messages.extend(await _previous_response_to_messages(previous_response))
|
||||
# TODO: refactor this user_content parsing out into a separate method
|
||||
user_content: str | list[OpenAIChatCompletionContentPartParam] = ""
|
||||
if isinstance(input, list):
|
||||
user_content = []
|
||||
for user_input in input:
|
||||
if isinstance(user_input.content, list):
|
||||
for user_input_content in user_input.content:
|
||||
if isinstance(user_input_content, OpenAIResponseInputMessageContentText):
|
||||
user_content.append(OpenAIChatCompletionContentPartTextParam(text=user_input_content.text))
|
||||
elif isinstance(user_input_content, OpenAIResponseInputMessageContentImage):
|
||||
if user_input_content.image_url:
|
||||
image_url = OpenAIImageURL(
|
||||
url=user_input_content.image_url, detail=user_input_content.detail
|
||||
)
|
||||
user_content.append(OpenAIChatCompletionContentPartImageParam(image_url=image_url))
|
||||
else:
|
||||
user_content.append(OpenAIChatCompletionContentPartTextParam(text=user_input.content))
|
||||
else:
|
||||
user_content = input
|
||||
messages.append(OpenAIUserMessageParam(content=user_content))
|
||||
|
||||
input = await self._prepend_previous_response(input, previous_response_id)
|
||||
messages = await _convert_response_input_to_chat_messages(input)
|
||||
chat_tools = await self._convert_response_tools_to_chat_tools(tools) if tools else None
|
||||
chat_response = await self.inference_api.openai_chat_completion(
|
||||
model=model,
|
||||
|
@ -150,6 +242,7 @@ class OpenAIResponsesImpl:
|
|||
# TODO: refactor this into a separate method that handles streaming
|
||||
chat_response_id = ""
|
||||
chat_response_content = []
|
||||
chat_response_tool_calls: dict[int, OpenAIChatCompletionToolCall] = {}
|
||||
# TODO: these chunk_ fields are hacky and only take the last chunk into account
|
||||
chunk_created = 0
|
||||
chunk_model = ""
|
||||
|
@ -163,7 +256,30 @@ class OpenAIResponsesImpl:
|
|||
chat_response_content.append(chunk_choice.delta.content or "")
|
||||
if chunk_choice.finish_reason:
|
||||
chunk_finish_reason = chunk_choice.finish_reason
|
||||
assistant_message = OpenAIAssistantMessageParam(content="".join(chat_response_content))
|
||||
|
||||
# Aggregate tool call arguments across chunks, using their index as the aggregation key
|
||||
if chunk_choice.delta.tool_calls:
|
||||
for tool_call in chunk_choice.delta.tool_calls:
|
||||
response_tool_call = chat_response_tool_calls.get(tool_call.index, None)
|
||||
if response_tool_call:
|
||||
response_tool_call.function.arguments += tool_call.function.arguments
|
||||
else:
|
||||
tool_call_dict: dict[str, Any] = tool_call.model_dump()
|
||||
# Ensure we don't have any empty type field in the tool call dict.
|
||||
# The OpenAI client used by providers often returns a type=None here.
|
||||
tool_call_dict.pop("type", None)
|
||||
response_tool_call = OpenAIChatCompletionToolCall(**tool_call_dict)
|
||||
chat_response_tool_calls[tool_call.index] = response_tool_call
|
||||
|
||||
# Convert the dict of tool calls by index to a list of tool calls to pass back in our response
|
||||
if chat_response_tool_calls:
|
||||
tool_calls = [chat_response_tool_calls[i] for i in sorted(chat_response_tool_calls.keys())]
|
||||
else:
|
||||
tool_calls = None
|
||||
assistant_message = OpenAIAssistantMessageParam(
|
||||
content="".join(chat_response_content),
|
||||
tool_calls=tool_calls,
|
||||
)
|
||||
chat_response = OpenAIChatCompletion(
|
||||
id=chat_response_id,
|
||||
choices=[
|
||||
|
@ -181,12 +297,26 @@ class OpenAIResponsesImpl:
|
|||
chat_response = OpenAIChatCompletion(**chat_response.model_dump())
|
||||
|
||||
output_messages: list[OpenAIResponseOutput] = []
|
||||
if chat_response.choices[0].message.tool_calls:
|
||||
output_messages.extend(
|
||||
await self._execute_tool_and_return_final_output(model, stream, chat_response, messages, temperature)
|
||||
for choice in chat_response.choices:
|
||||
if choice.message.tool_calls and tools:
|
||||
# Assume if the first tool is a function, all tools are functions
|
||||
if isinstance(tools[0], OpenAIResponseInputToolFunction):
|
||||
for tool_call in choice.message.tool_calls:
|
||||
output_messages.append(
|
||||
OpenAIResponseOutputMessageFunctionToolCall(
|
||||
arguments=tool_call.function.arguments or "",
|
||||
call_id=tool_call.id,
|
||||
name=tool_call.function.name or "",
|
||||
id=f"fc_{uuid.uuid4()}",
|
||||
status="completed",
|
||||
)
|
||||
)
|
||||
else:
|
||||
output_messages.extend(await _openai_choices_to_output_messages(chat_response.choices))
|
||||
output_messages.extend(
|
||||
await self._execute_tool_and_return_final_output(model, stream, choice, messages, temperature)
|
||||
)
|
||||
else:
|
||||
output_messages.append(await _convert_chat_choice_to_response_message(choice))
|
||||
response = OpenAIResponseObject(
|
||||
created_at=chat_response.created,
|
||||
id=f"resp-{uuid.uuid4()}",
|
||||
|
@ -195,13 +325,43 @@ class OpenAIResponsesImpl:
|
|||
status="completed",
|
||||
output=output_messages,
|
||||
)
|
||||
logger.debug(f"OpenAI Responses response: {response}")
|
||||
|
||||
if store:
|
||||
# Store in kvstore
|
||||
|
||||
new_input_id = f"msg_{uuid.uuid4()}"
|
||||
if isinstance(input, str):
|
||||
# synthesize a message from the input string
|
||||
input_content = OpenAIResponseInputMessageContentText(text=input)
|
||||
input_content_item = OpenAIResponseMessage(
|
||||
role="user",
|
||||
content=[input_content],
|
||||
id=new_input_id,
|
||||
)
|
||||
input_items_data = [input_content_item]
|
||||
else:
|
||||
# we already have a list of messages
|
||||
input_items_data = []
|
||||
for input_item in input:
|
||||
if isinstance(input_item, OpenAIResponseMessage):
|
||||
# These may or may not already have an id, so dump to dict, check for id, and add if missing
|
||||
input_item_dict = input_item.model_dump()
|
||||
if "id" not in input_item_dict:
|
||||
input_item_dict["id"] = new_input_id
|
||||
input_items_data.append(OpenAIResponseMessage(**input_item_dict))
|
||||
else:
|
||||
input_items_data.append(input_item)
|
||||
|
||||
input_items = OpenAIResponseInputItemList(data=input_items_data)
|
||||
prev_response = OpenAIResponsePreviousResponseWithInputItems(
|
||||
input_items=input_items,
|
||||
response=response,
|
||||
)
|
||||
key = f"{OPENAI_RESPONSES_PREFIX}{response.id}"
|
||||
await self.persistence_store.set(
|
||||
key=key,
|
||||
value=response.model_dump_json(),
|
||||
value=prev_response.model_dump_json(),
|
||||
)
|
||||
|
||||
if stream:
|
||||
|
@ -221,7 +381,9 @@ class OpenAIResponsesImpl:
|
|||
chat_tools: list[ChatCompletionToolParam] = []
|
||||
for input_tool in tools:
|
||||
# TODO: Handle other tool types
|
||||
if input_tool.type == "web_search":
|
||||
if input_tool.type == "function":
|
||||
chat_tools.append(ChatCompletionToolParam(type="function", function=input_tool.model_dump()))
|
||||
elif input_tool.type == "web_search":
|
||||
tool_name = "web_search"
|
||||
tool = await self.tool_groups_api.get_tool(tool_name)
|
||||
tool_def = ToolDefinition(
|
||||
|
@ -247,12 +409,11 @@ class OpenAIResponsesImpl:
|
|||
self,
|
||||
model_id: str,
|
||||
stream: bool,
|
||||
chat_response: OpenAIChatCompletion,
|
||||
choice: OpenAIChoice,
|
||||
messages: list[OpenAIMessageParam],
|
||||
temperature: float,
|
||||
) -> list[OpenAIResponseOutput]:
|
||||
output_messages: list[OpenAIResponseOutput] = []
|
||||
choice = chat_response.choices[0]
|
||||
|
||||
# If the choice is not an assistant message, we don't need to execute any tools
|
||||
if not isinstance(choice.message, OpenAIAssistantMessageParam):
|
||||
|
@ -262,6 +423,9 @@ class OpenAIResponsesImpl:
|
|||
if not choice.message.tool_calls:
|
||||
return output_messages
|
||||
|
||||
# Copy the messages list to avoid mutating the original list
|
||||
messages = messages.copy()
|
||||
|
||||
# Add the assistant message with tool_calls response to the messages list
|
||||
messages.append(choice.message)
|
||||
|
||||
|
@ -307,7 +471,9 @@ class OpenAIResponsesImpl:
|
|||
)
|
||||
# type cast to appease mypy
|
||||
tool_results_chat_response = cast(OpenAIChatCompletion, tool_results_chat_response)
|
||||
tool_final_outputs = await _openai_choices_to_output_messages(tool_results_chat_response.choices)
|
||||
tool_final_outputs = [
|
||||
await _convert_chat_choice_to_response_message(choice) for choice in tool_results_chat_response.choices
|
||||
]
|
||||
# TODO: Wire in annotations with URLs, titles, etc to these output messages
|
||||
output_messages.extend(tool_final_outputs)
|
||||
return output_messages
|
||||
|
|
|
@ -9,9 +9,7 @@ import logging
|
|||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from llama_stack.apis.agents import ToolExecutionStep, Turn
|
||||
from llama_stack.apis.agents import AgentConfig, Session, ToolExecutionStep, Turn
|
||||
from llama_stack.distribution.access_control import check_access
|
||||
from llama_stack.distribution.datatypes import AccessAttributes
|
||||
from llama_stack.distribution.request_headers import get_auth_attributes
|
||||
|
@ -20,15 +18,17 @@ from llama_stack.providers.utils.kvstore import KVStore
|
|||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AgentSessionInfo(BaseModel):
|
||||
session_id: str
|
||||
session_name: str
|
||||
class AgentSessionInfo(Session):
|
||||
# TODO: is this used anywhere?
|
||||
vector_db_id: str | None = None
|
||||
started_at: datetime
|
||||
access_attributes: AccessAttributes | None = None
|
||||
|
||||
|
||||
class AgentInfo(AgentConfig):
|
||||
created_at: datetime
|
||||
|
||||
|
||||
class AgentPersistence:
|
||||
def __init__(self, agent_id: str, kvstore: KVStore):
|
||||
self.agent_id = agent_id
|
||||
|
@ -46,6 +46,7 @@ class AgentPersistence:
|
|||
session_name=name,
|
||||
started_at=datetime.now(timezone.utc),
|
||||
access_attributes=access_attributes,
|
||||
turns=[],
|
||||
)
|
||||
|
||||
await self.kvstore.set(
|
||||
|
@ -109,7 +110,7 @@ class AgentPersistence:
|
|||
if not await self.get_session_if_accessible(session_id):
|
||||
raise ValueError(f"Session {session_id} not found or access denied")
|
||||
|
||||
values = await self.kvstore.range(
|
||||
values = await self.kvstore.values_in_range(
|
||||
start_key=f"session:{self.agent_id}:{session_id}:",
|
||||
end_key=f"session:{self.agent_id}:{session_id}:\xff\xff\xff\xff",
|
||||
)
|
||||
|
@ -121,7 +122,6 @@ class AgentPersistence:
|
|||
except Exception as e:
|
||||
log.error(f"Error parsing turn: {e}")
|
||||
continue
|
||||
turns.sort(key=lambda x: (x.completed_at or datetime.min))
|
||||
return turns
|
||||
|
||||
async def get_session_turn(self, session_id: str, turn_id: str) -> Turn | None:
|
||||
|
@ -170,3 +170,43 @@ class AgentPersistence:
|
|||
key=f"num_infer_iters_in_turn:{self.agent_id}:{session_id}:{turn_id}",
|
||||
)
|
||||
return int(value) if value else None
|
||||
|
||||
async def list_sessions(self) -> list[Session]:
|
||||
values = await self.kvstore.values_in_range(
|
||||
start_key=f"session:{self.agent_id}:",
|
||||
end_key=f"session:{self.agent_id}:\xff\xff\xff\xff",
|
||||
)
|
||||
sessions = []
|
||||
for value in values:
|
||||
try:
|
||||
session_info = Session(**json.loads(value))
|
||||
sessions.append(session_info)
|
||||
except Exception as e:
|
||||
log.error(f"Error parsing session info: {e}")
|
||||
continue
|
||||
return sessions
|
||||
|
||||
async def delete_session_turns(self, session_id: str) -> None:
|
||||
"""Delete all turns and their associated data for a session.
|
||||
|
||||
Args:
|
||||
session_id: The ID of the session whose turns should be deleted.
|
||||
"""
|
||||
turns = await self.get_session_turns(session_id)
|
||||
for turn in turns:
|
||||
await self.kvstore.delete(key=f"session:{self.agent_id}:{session_id}:{turn.turn_id}")
|
||||
|
||||
async def delete_session(self, session_id: str) -> None:
|
||||
"""Delete a session and all its associated turns.
|
||||
|
||||
Args:
|
||||
session_id: The ID of the session to delete.
|
||||
|
||||
Raises:
|
||||
ValueError: If the session does not exist.
|
||||
"""
|
||||
session_info = await self.get_session_info(session_id)
|
||||
if session_info is None:
|
||||
raise ValueError(f"Session {session_id} not found")
|
||||
|
||||
await self.kvstore.delete(key=f"session:{self.agent_id}:{session_id}")
|
||||
|
|
|
@ -11,9 +11,9 @@ from llama_stack.apis.common.responses import PaginatedResponse
|
|||
from llama_stack.apis.datasetio import DatasetIO
|
||||
from llama_stack.apis.datasets import Dataset
|
||||
from llama_stack.providers.datatypes import DatasetsProtocolPrivate
|
||||
from llama_stack.providers.utils.datasetio.pagination import paginate_records
|
||||
from llama_stack.providers.utils.datasetio.url_utils import get_dataframe_from_uri
|
||||
from llama_stack.providers.utils.kvstore import kvstore_impl
|
||||
from llama_stack.providers.utils.pagination import paginate_records
|
||||
|
||||
from .config import LocalFSDatasetIOConfig
|
||||
|
||||
|
@ -64,7 +64,7 @@ class LocalFSDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
|
|||
# Load existing datasets from kvstore
|
||||
start_key = DATASETS_PREFIX
|
||||
end_key = f"{DATASETS_PREFIX}\xff"
|
||||
stored_datasets = await self.kvstore.range(start_key, end_key)
|
||||
stored_datasets = await self.kvstore.values_in_range(start_key, end_key)
|
||||
|
||||
for dataset in stored_datasets:
|
||||
dataset = Dataset.model_validate_json(dataset)
|
||||
|
|
|
@ -58,7 +58,7 @@ class MetaReferenceEvalImpl(
|
|||
# Load existing benchmarks from kvstore
|
||||
start_key = EVAL_TASKS_PREFIX
|
||||
end_key = f"{EVAL_TASKS_PREFIX}\xff"
|
||||
stored_benchmarks = await self.kvstore.range(start_key, end_key)
|
||||
stored_benchmarks = await self.kvstore.values_in_range(start_key, end_key)
|
||||
|
||||
for benchmark in stored_benchmarks:
|
||||
benchmark = Benchmark.model_validate_json(benchmark)
|
||||
|
|
|
@ -28,7 +28,7 @@ from llama_stack.apis.inference import (
|
|||
CompletionRequest,
|
||||
CompletionResponse,
|
||||
CompletionResponseStreamChunk,
|
||||
Inference,
|
||||
InferenceProvider,
|
||||
InterleavedContent,
|
||||
LogProbConfig,
|
||||
Message,
|
||||
|
@ -86,7 +86,7 @@ class MetaReferenceInferenceImpl(
|
|||
OpenAICompletionToLlamaStackMixin,
|
||||
OpenAIChatCompletionToLlamaStackMixin,
|
||||
SentenceTransformerEmbeddingMixin,
|
||||
Inference,
|
||||
InferenceProvider,
|
||||
ModelsProtocolPrivate,
|
||||
):
|
||||
def __init__(self, config: MetaReferenceInferenceConfig) -> None:
|
||||
|
|
|
@ -9,7 +9,7 @@ from collections.abc import AsyncGenerator
|
|||
|
||||
from llama_stack.apis.inference import (
|
||||
CompletionResponse,
|
||||
Inference,
|
||||
InferenceProvider,
|
||||
InterleavedContent,
|
||||
LogProbConfig,
|
||||
Message,
|
||||
|
@ -38,7 +38,7 @@ class SentenceTransformersInferenceImpl(
|
|||
OpenAIChatCompletionToLlamaStackMixin,
|
||||
OpenAICompletionToLlamaStackMixin,
|
||||
SentenceTransformerEmbeddingMixin,
|
||||
Inference,
|
||||
InferenceProvider,
|
||||
ModelsProtocolPrivate,
|
||||
):
|
||||
def __init__(self, config: SentenceTransformersInferenceConfig) -> None:
|
||||
|
|
35
llama_stack/providers/inline/post_training/common/utils.py
Normal file
35
llama_stack/providers/inline/post_training/common/utils.py
Normal file
|
@ -0,0 +1,35 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import gc
|
||||
|
||||
|
||||
def evacuate_model_from_device(model, device: str):
|
||||
"""Safely clear a model from memory and free device resources.
|
||||
This function handles the proper cleanup of a model by:
|
||||
1. Moving the model to CPU if it's on a non-CPU device
|
||||
2. Deleting the model object to free memory
|
||||
3. Running garbage collection
|
||||
4. Clearing CUDA cache if the model was on a CUDA device
|
||||
Args:
|
||||
model: The PyTorch model to clear
|
||||
device: The device type the model is currently on ('cuda', 'mps', 'cpu')
|
||||
Note:
|
||||
- For CUDA devices, this will clear the CUDA cache after moving the model to CPU
|
||||
- For MPS devices, only moves the model to CPU (no cache clearing available)
|
||||
- For CPU devices, only deletes the model object and runs garbage collection
|
||||
"""
|
||||
if device != "cpu":
|
||||
model.to("cpu")
|
||||
|
||||
del model
|
||||
gc.collect()
|
||||
|
||||
if device == "cuda":
|
||||
# we need to import such that this is only imported when the method is called
|
||||
import torch
|
||||
|
||||
torch.cuda.empty_cache()
|
|
@ -0,0 +1,27 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any
|
||||
|
||||
from llama_stack.distribution.datatypes import Api
|
||||
|
||||
from .config import HuggingFacePostTrainingConfig
|
||||
|
||||
# post_training api and the huggingface provider is still experimental and under heavy development
|
||||
|
||||
|
||||
async def get_provider_impl(
|
||||
config: HuggingFacePostTrainingConfig,
|
||||
deps: dict[Api, Any],
|
||||
):
|
||||
from .post_training import HuggingFacePostTrainingImpl
|
||||
|
||||
impl = HuggingFacePostTrainingImpl(
|
||||
config,
|
||||
deps[Api.datasetio],
|
||||
deps[Api.datasets],
|
||||
)
|
||||
return impl
|
|
@ -0,0 +1,72 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, Literal
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class HuggingFacePostTrainingConfig(BaseModel):
|
||||
# Device to run training on (cuda, cpu, mps)
|
||||
device: str = "cuda"
|
||||
|
||||
# Distributed training backend if using multiple devices
|
||||
# fsdp: Fully Sharded Data Parallel
|
||||
# deepspeed: DeepSpeed ZeRO optimization
|
||||
distributed_backend: Literal["fsdp", "deepspeed"] | None = None
|
||||
|
||||
# Format for saving model checkpoints
|
||||
# full_state: Save complete model state
|
||||
# huggingface: Save in HuggingFace format (recommended for compatibility)
|
||||
checkpoint_format: Literal["full_state", "huggingface"] | None = "huggingface"
|
||||
|
||||
# Template for formatting chat inputs and outputs
|
||||
# Used to structure the conversation format for training
|
||||
chat_template: str = "<|user|>\n{input}\n<|assistant|>\n{output}"
|
||||
|
||||
# Model-specific configuration parameters
|
||||
# trust_remote_code: Allow execution of custom model code
|
||||
# attn_implementation: Use SDPA (Scaled Dot Product Attention) for better performance
|
||||
model_specific_config: dict = {
|
||||
"trust_remote_code": True,
|
||||
"attn_implementation": "sdpa",
|
||||
}
|
||||
|
||||
# Maximum sequence length for training
|
||||
# Set to 2048 as this is the maximum that works reliably on MPS (Apple Silicon)
|
||||
# Longer sequences may cause memory issues on MPS devices
|
||||
max_seq_length: int = 2048
|
||||
|
||||
# Enable gradient checkpointing to reduce memory usage
|
||||
# Trades computation for memory by recomputing activations
|
||||
gradient_checkpointing: bool = False
|
||||
|
||||
# Maximum number of checkpoints to keep
|
||||
# Older checkpoints are deleted when this limit is reached
|
||||
save_total_limit: int = 3
|
||||
|
||||
# Number of training steps between logging updates
|
||||
logging_steps: int = 10
|
||||
|
||||
# Ratio of training steps used for learning rate warmup
|
||||
# Helps stabilize early training
|
||||
warmup_ratio: float = 0.1
|
||||
|
||||
# L2 regularization coefficient
|
||||
# Helps prevent overfitting
|
||||
weight_decay: float = 0.01
|
||||
|
||||
# Number of worker processes for data loading
|
||||
# Higher values can improve data loading speed but increase memory usage
|
||||
dataloader_num_workers: int = 4
|
||||
|
||||
# Whether to pin memory in data loader
|
||||
# Can improve data transfer speed to GPU but uses more memory
|
||||
dataloader_pin_memory: bool = True
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> dict[str, Any]:
|
||||
return {"checkpoint_format": "huggingface", "distributed_backend": None, "device": "cpu"}
|
|
@ -0,0 +1,176 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
from llama_stack.apis.datasetio import DatasetIO
|
||||
from llama_stack.apis.datasets import Datasets
|
||||
from llama_stack.apis.post_training import (
|
||||
AlgorithmConfig,
|
||||
Checkpoint,
|
||||
DPOAlignmentConfig,
|
||||
JobStatus,
|
||||
ListPostTrainingJobsResponse,
|
||||
PostTrainingJob,
|
||||
PostTrainingJobArtifactsResponse,
|
||||
PostTrainingJobStatusResponse,
|
||||
TrainingConfig,
|
||||
)
|
||||
from llama_stack.providers.inline.post_training.huggingface.config import (
|
||||
HuggingFacePostTrainingConfig,
|
||||
)
|
||||
from llama_stack.providers.inline.post_training.huggingface.recipes.finetune_single_device import (
|
||||
HFFinetuningSingleDevice,
|
||||
)
|
||||
from llama_stack.providers.utils.scheduler import JobArtifact, Scheduler
|
||||
from llama_stack.providers.utils.scheduler import JobStatus as SchedulerJobStatus
|
||||
from llama_stack.schema_utils import webmethod
|
||||
|
||||
|
||||
class TrainingArtifactType(Enum):
|
||||
CHECKPOINT = "checkpoint"
|
||||
RESOURCES_STATS = "resources_stats"
|
||||
|
||||
|
||||
_JOB_TYPE_SUPERVISED_FINE_TUNE = "supervised-fine-tune"
|
||||
|
||||
|
||||
class HuggingFacePostTrainingImpl:
|
||||
def __init__(
|
||||
self,
|
||||
config: HuggingFacePostTrainingConfig,
|
||||
datasetio_api: DatasetIO,
|
||||
datasets: Datasets,
|
||||
) -> None:
|
||||
self.config = config
|
||||
self.datasetio_api = datasetio_api
|
||||
self.datasets_api = datasets
|
||||
self._scheduler = Scheduler()
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
await self._scheduler.shutdown()
|
||||
|
||||
@staticmethod
|
||||
def _checkpoint_to_artifact(checkpoint: Checkpoint) -> JobArtifact:
|
||||
return JobArtifact(
|
||||
type=TrainingArtifactType.CHECKPOINT.value,
|
||||
name=checkpoint.identifier,
|
||||
uri=checkpoint.path,
|
||||
metadata=dict(checkpoint),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _resources_stats_to_artifact(resources_stats: dict[str, Any]) -> JobArtifact:
|
||||
return JobArtifact(
|
||||
type=TrainingArtifactType.RESOURCES_STATS.value,
|
||||
name=TrainingArtifactType.RESOURCES_STATS.value,
|
||||
metadata=resources_stats,
|
||||
)
|
||||
|
||||
async def supervised_fine_tune(
|
||||
self,
|
||||
job_uuid: str,
|
||||
training_config: TrainingConfig,
|
||||
hyperparam_search_config: dict[str, Any],
|
||||
logger_config: dict[str, Any],
|
||||
model: str,
|
||||
checkpoint_dir: str | None = None,
|
||||
algorithm_config: AlgorithmConfig | None = None,
|
||||
) -> PostTrainingJob:
|
||||
async def handler(on_log_message_cb, on_status_change_cb, on_artifact_collected_cb):
|
||||
on_log_message_cb("Starting HF finetuning")
|
||||
|
||||
recipe = HFFinetuningSingleDevice(
|
||||
job_uuid=job_uuid,
|
||||
datasetio_api=self.datasetio_api,
|
||||
datasets_api=self.datasets_api,
|
||||
)
|
||||
|
||||
resources_allocated, checkpoints = await recipe.train(
|
||||
model=model,
|
||||
output_dir=checkpoint_dir,
|
||||
job_uuid=job_uuid,
|
||||
lora_config=algorithm_config,
|
||||
config=training_config,
|
||||
provider_config=self.config,
|
||||
)
|
||||
|
||||
on_artifact_collected_cb(self._resources_stats_to_artifact(resources_allocated))
|
||||
if checkpoints:
|
||||
for checkpoint in checkpoints:
|
||||
artifact = self._checkpoint_to_artifact(checkpoint)
|
||||
on_artifact_collected_cb(artifact)
|
||||
|
||||
on_status_change_cb(SchedulerJobStatus.completed)
|
||||
on_log_message_cb("HF finetuning completed")
|
||||
|
||||
job_uuid = self._scheduler.schedule(_JOB_TYPE_SUPERVISED_FINE_TUNE, job_uuid, handler)
|
||||
return PostTrainingJob(job_uuid=job_uuid)
|
||||
|
||||
async def preference_optimize(
|
||||
self,
|
||||
job_uuid: str,
|
||||
finetuned_model: str,
|
||||
algorithm_config: DPOAlignmentConfig,
|
||||
training_config: TrainingConfig,
|
||||
hyperparam_search_config: dict[str, Any],
|
||||
logger_config: dict[str, Any],
|
||||
) -> PostTrainingJob:
|
||||
raise NotImplementedError("DPO alignment is not implemented yet")
|
||||
|
||||
async def get_training_jobs(self) -> ListPostTrainingJobsResponse:
|
||||
return ListPostTrainingJobsResponse(
|
||||
data=[PostTrainingJob(job_uuid=job.id) for job in self._scheduler.get_jobs()]
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _get_artifacts_metadata_by_type(job, artifact_type):
|
||||
return [artifact.metadata for artifact in job.artifacts if artifact.type == artifact_type]
|
||||
|
||||
@classmethod
|
||||
def _get_checkpoints(cls, job):
|
||||
return cls._get_artifacts_metadata_by_type(job, TrainingArtifactType.CHECKPOINT.value)
|
||||
|
||||
@classmethod
|
||||
def _get_resources_allocated(cls, job):
|
||||
data = cls._get_artifacts_metadata_by_type(job, TrainingArtifactType.RESOURCES_STATS.value)
|
||||
return data[0] if data else None
|
||||
|
||||
@webmethod(route="/post-training/job/status")
|
||||
async def get_training_job_status(self, job_uuid: str) -> PostTrainingJobStatusResponse | None:
|
||||
job = self._scheduler.get_job(job_uuid)
|
||||
|
||||
match job.status:
|
||||
# TODO: Add support for other statuses to API
|
||||
case SchedulerJobStatus.new | SchedulerJobStatus.scheduled:
|
||||
status = JobStatus.scheduled
|
||||
case SchedulerJobStatus.running:
|
||||
status = JobStatus.in_progress
|
||||
case SchedulerJobStatus.completed:
|
||||
status = JobStatus.completed
|
||||
case SchedulerJobStatus.failed:
|
||||
status = JobStatus.failed
|
||||
case _:
|
||||
raise NotImplementedError()
|
||||
|
||||
return PostTrainingJobStatusResponse(
|
||||
job_uuid=job_uuid,
|
||||
status=status,
|
||||
scheduled_at=job.scheduled_at,
|
||||
started_at=job.started_at,
|
||||
completed_at=job.completed_at,
|
||||
checkpoints=self._get_checkpoints(job),
|
||||
resources_allocated=self._get_resources_allocated(job),
|
||||
)
|
||||
|
||||
@webmethod(route="/post-training/job/cancel")
|
||||
async def cancel_training_job(self, job_uuid: str) -> None:
|
||||
self._scheduler.cancel(job_uuid)
|
||||
|
||||
@webmethod(route="/post-training/job/artifacts")
|
||||
async def get_training_job_artifacts(self, job_uuid: str) -> PostTrainingJobArtifactsResponse | None:
|
||||
job = self._scheduler.get_job(job_uuid)
|
||||
return PostTrainingJobArtifactsResponse(job_uuid=job_uuid, checkpoints=self._get_checkpoints(job))
|
Some files were not shown because too many files have changed in this diff Show more
Loading…
Add table
Add a link
Reference in a new issue