Merge branch 'main' into dead_code_removal

Omar Abdelwahab 2025-10-06 13:21:36 -07:00 committed by GitHub
commit 9886520b40
927 changed files with 171924 additions and 102933 deletions

.github/CODEOWNERS

@@ -2,4 +2,4 @@
 # These owners will be the default owners for everything in
 # the repo. Unless a later match takes precedence,
-* @ashwinb @yanxi0830 @hardikjshah @raghotham @ehhuang @terrytangyuan @leseb @bbrowning @reluctantfuturist @mattf @slekkala1
+* @ashwinb @yanxi0830 @hardikjshah @raghotham @ehhuang @terrytangyuan @leseb @bbrowning @reluctantfuturist @mattf @slekkala1 @franciscojavierarceo

.github/TRIAGERS.md

@@ -1,2 +1 @@
 # This file documents Triage members in the Llama Stack community
-@franciscojavierarceo

@@ -12,6 +12,7 @@ Llama Stack uses GitHub Actions for Continuous Integration (CI). Below is a tabl
 | Integration Tests (Replay) | [integration-tests.yml](integration-tests.yml) | Run the integration test suites from tests/integration in replay mode |
 | Vector IO Integration Tests | [integration-vector-io-tests.yml](integration-vector-io-tests.yml) | Run the integration test suite with various VectorIO providers |
 | Pre-commit | [pre-commit.yml](pre-commit.yml) | Run pre-commit checks |
+| Pre-commit Bot | [precommit-trigger.yml](precommit-trigger.yml) | Pre-commit bot for PR |
 | Test Llama Stack Build | [providers-build.yml](providers-build.yml) | Test llama stack build |
 | Python Package Build Test | [python-build-test.yml](python-build-test.yml) | Test building the llama-stack PyPI project |
 | Integration Tests (Record) | [record-integration-tests.yml](record-integration-tests.yml) | Run the integration test suite from tests/integration |

@@ -84,6 +84,8 @@ jobs:
           yq eval '.server.auth.provider_config.jwks.token = "${{ env.TOKEN }}"' -i $run_dir/run.yaml
           cat $run_dir/run.yaml
+          # avoid line breaks in the server log, especially because we grep it below.
+          export COLUMNS=1984
           nohup uv run llama stack run $run_dir/run.yaml --image-type venv > server.log 2>&1 &
       - name: Wait for Llama Stack server to be ready

@@ -42,18 +42,27 @@ jobs:
   run-replay-mode-tests:
     runs-on: ubuntu-latest
-    name: ${{ format('Integration Tests ({0}, {1}, {2}, client={3}, {4})', matrix.client-type, matrix.setup, matrix.python-version, matrix.client-version, matrix.suite) }}
+    name: ${{ format('Integration Tests ({0}, {1}, {2}, client={3}, {4})', matrix.client-type, matrix.config.setup, matrix.python-version, matrix.client-version, matrix.config.suite) }}
     strategy:
       fail-fast: false
       matrix:
         client-type: [library, server]
-        # Use vllm on weekly schedule, otherwise use test-setup input (defaults to ollama)
-        setup: ${{ (github.event.schedule == '1 0 * * 0') && fromJSON('["vllm"]') || fromJSON(format('["{0}"]', github.event.inputs.test-setup || 'ollama')) }}
         # Use Python 3.13 only on nightly schedule (daily latest client test), otherwise use 3.12
         python-version: ${{ github.event.schedule == '0 0 * * *' && fromJSON('["3.12", "3.13"]') || fromJSON('["3.12"]') }}
        client-version: ${{ (github.event.schedule == '0 0 * * *' || github.event.inputs.test-all-client-versions == 'true') && fromJSON('["published", "latest"]') || fromJSON('["latest"]') }}
-        suite: [base, vision]
+        # Define (setup, suite) pairs - they are always matched and cannot be independent
+        # Weekly schedule (Sun 1 AM): vllm+base
+        # Input test-setup=ollama-vision: ollama-vision+vision
+        # Default (including test-setup=ollama): both ollama+base and ollama-vision+vision
+        config: >-
+          ${{
+            github.event.schedule == '1 0 * * 0'
+            && fromJSON('[{"setup": "vllm", "suite": "base"}]')
+            || github.event.inputs.test-setup == 'ollama-vision'
+            && fromJSON('[{"setup": "ollama-vision", "suite": "vision"}]')
+            || fromJSON('[{"setup": "ollama", "suite": "base"}, {"setup": "ollama-vision", "suite": "vision"}]')
+          }}
     steps:
       - name: Checkout repository
@@ -64,14 +73,14 @@ jobs:
         with:
           python-version: ${{ matrix.python-version }}
           client-version: ${{ matrix.client-version }}
-          setup: ${{ matrix.setup }}
-          suite: ${{ matrix.suite }}
+          setup: ${{ matrix.config.setup }}
+          suite: ${{ matrix.config.suite }}
           inference-mode: 'replay'
       - name: Run tests
         uses: ./.github/actions/run-and-record-tests
         with:
           stack-config: ${{ matrix.client-type == 'library' && 'ci-tests' || 'server:ci-tests' }}
-          setup: ${{ matrix.setup }}
+          setup: ${{ matrix.config.setup }}
           inference-mode: 'replay'
-          suite: ${{ matrix.suite }}
+          suite: ${{ matrix.config.suite }}
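GitHub expressions chain `&&`/`||` with JavaScript-style short-circuiting, so the first truthy branch in the `config` entry above supplies the whole matrix list. A small Python sketch of the equivalent selection logic (function name hypothetical, not part of this change):

```python
# Mirrors the chained &&/|| expression in the `config` matrix entry above.
def pick_configs(schedule: str, test_setup: str) -> list[dict[str, str]]:
    if schedule == "1 0 * * 0":  # weekly schedule -> vllm+base
        return [{"setup": "vllm", "suite": "base"}]
    if test_setup == "ollama-vision":  # explicit vision-only request
        return [{"setup": "ollama-vision", "suite": "vision"}]
    # default (including test-setup=ollama): both pairs
    return [
        {"setup": "ollama", "suite": "base"},
        {"setup": "ollama-vision", "suite": "vision"},
    ]
```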

.github/workflows/precommit-trigger.yml (new file)

@@ -0,0 +1,227 @@
name: Pre-commit Bot

run-name: Pre-commit bot for PR #${{ github.event.issue.number }}

on:
  issue_comment:
    types: [created]

jobs:
  pre-commit:
    # Only run on pull request comments
    if: github.event.issue.pull_request && contains(github.event.comment.body, '@github-actions run precommit')
    runs-on: ubuntu-latest
    permissions:
      contents: write
      pull-requests: write
    steps:
      - name: Check comment author and get PR details
        id: check_author
        uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8.0.0
        with:
          github-token: ${{ secrets.GITHUB_TOKEN }}
          script: |
            // Get PR details
            const pr = await github.rest.pulls.get({
              owner: context.repo.owner,
              repo: context.repo.repo,
              pull_number: context.issue.number
            });

            // Check if commenter has write access or is the PR author
            const commenter = context.payload.comment.user.login;
            const prAuthor = pr.data.user.login;
            let hasPermission = false;

            // Check if commenter is PR author
            if (commenter === prAuthor) {
              hasPermission = true;
              console.log(`Comment author ${commenter} is the PR author`);
            } else {
              // Check if commenter has write/admin access
              try {
                const permission = await github.rest.repos.getCollaboratorPermissionLevel({
                  owner: context.repo.owner,
                  repo: context.repo.repo,
                  username: commenter
                });
                const level = permission.data.permission;
                hasPermission = ['write', 'admin', 'maintain'].includes(level);
                console.log(`Comment author ${commenter} has permission: ${level}`);
              } catch (error) {
                console.log(`Could not check permissions for ${commenter}: ${error.message}`);
              }
            }

            if (!hasPermission) {
              await github.rest.issues.createComment({
                owner: context.repo.owner,
                repo: context.repo.repo,
                issue_number: context.issue.number,
                body: `❌ @${commenter} You don't have permission to trigger pre-commit. Only PR authors or repository collaborators can run this command.`
              });
              core.setFailed(`User ${commenter} does not have permission`);
              return;
            }

            // Save PR info for later steps
            core.setOutput('pr_number', context.issue.number);
            core.setOutput('pr_head_ref', pr.data.head.ref);
            core.setOutput('pr_head_sha', pr.data.head.sha);
            core.setOutput('pr_head_repo', pr.data.head.repo.full_name);
            core.setOutput('pr_base_ref', pr.data.base.ref);
            core.setOutput('is_fork', pr.data.head.repo.full_name !== context.payload.repository.full_name);
            core.setOutput('authorized', 'true');

      - name: React to comment
        if: steps.check_author.outputs.authorized == 'true'
        uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8.0.0
        with:
          github-token: ${{ secrets.GITHUB_TOKEN }}
          script: |
            await github.rest.reactions.createForIssueComment({
              owner: context.repo.owner,
              repo: context.repo.repo,
              comment_id: context.payload.comment.id,
              content: 'rocket'
            });

      - name: Comment starting
        if: steps.check_author.outputs.authorized == 'true'
        uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8.0.0
        with:
          github-token: ${{ secrets.GITHUB_TOKEN }}
          script: |
            await github.rest.issues.createComment({
              owner: context.repo.owner,
              repo: context.repo.repo,
              issue_number: ${{ steps.check_author.outputs.pr_number }},
              body: `⏳ Running pre-commit hooks on PR #${{ steps.check_author.outputs.pr_number }}...`
            });

      - name: Checkout PR branch (same-repo)
        if: steps.check_author.outputs.authorized == 'true' && steps.check_author.outputs.is_fork == 'false'
        uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
        with:
          ref: ${{ steps.check_author.outputs.pr_head_ref }}
          fetch-depth: 0
          token: ${{ secrets.GITHUB_TOKEN }}

      - name: Checkout PR branch (fork)
        if: steps.check_author.outputs.authorized == 'true' && steps.check_author.outputs.is_fork == 'true'
        uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
        with:
          repository: ${{ steps.check_author.outputs.pr_head_repo }}
          ref: ${{ steps.check_author.outputs.pr_head_ref }}
          fetch-depth: 0
          token: ${{ secrets.GITHUB_TOKEN }}

      - name: Verify checkout
        if: steps.check_author.outputs.authorized == 'true'
        run: |
          echo "Current SHA: $(git rev-parse HEAD)"
          echo "Expected SHA: ${{ steps.check_author.outputs.pr_head_sha }}"
          if [[ "$(git rev-parse HEAD)" != "${{ steps.check_author.outputs.pr_head_sha }}" ]]; then
            echo "::error::Checked out SHA does not match expected SHA"
            exit 1
          fi

      - name: Set up Python
        if: steps.check_author.outputs.authorized == 'true'
        uses: actions/setup-python@e797f83bcb11b83ae66e0230d6156d7c80228e7c # v6.0.0
        with:
          python-version: '3.12'
          cache: pip
          cache-dependency-path: |
            **/requirements*.txt
            .pre-commit-config.yaml

      - name: Set up Node.js
        if: steps.check_author.outputs.authorized == 'true'
        uses: actions/setup-node@a0853c24544627f65ddf259abe73b1d18a591444 # v5.0.0
        with:
          node-version: '20'
          cache: 'npm'
          cache-dependency-path: 'llama_stack/ui/'

      - name: Install npm dependencies
        if: steps.check_author.outputs.authorized == 'true'
        run: npm ci
        working-directory: llama_stack/ui

      - name: Run pre-commit
        if: steps.check_author.outputs.authorized == 'true'
        id: precommit
        uses: pre-commit/action@2c7b3805fd2a0fd8c1884dcaebf91fc102a13ecd # v3.0.1
        continue-on-error: true
        env:
          SKIP: no-commit-to-branch
          RUFF_OUTPUT_FORMAT: github

      - name: Check for changes
        if: steps.check_author.outputs.authorized == 'true'
        id: changes
        run: |
          if ! git diff --exit-code || [ -n "$(git ls-files --others --exclude-standard)" ]; then
            echo "has_changes=true" >> $GITHUB_OUTPUT
            echo "Changes detected after pre-commit"
          else
            echo "has_changes=false" >> $GITHUB_OUTPUT
            echo "No changes after pre-commit"
          fi

      - name: Commit and push changes
        if: steps.check_author.outputs.authorized == 'true' && steps.changes.outputs.has_changes == 'true'
        run: |
          git config --local user.email "github-actions[bot]@users.noreply.github.com"
          git config --local user.name "github-actions[bot]"
          git add -A
          git commit -m "style: apply pre-commit fixes
          🤖 Applied by @github-actions bot via pre-commit workflow"
          # Push changes
          git push origin HEAD:${{ steps.check_author.outputs.pr_head_ref }}

      - name: Comment success with changes
        if: steps.check_author.outputs.authorized == 'true' && steps.changes.outputs.has_changes == 'true'
        uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8.0.0
        with:
          github-token: ${{ secrets.GITHUB_TOKEN }}
          script: |
            await github.rest.issues.createComment({
              owner: context.repo.owner,
              repo: context.repo.repo,
              issue_number: ${{ steps.check_author.outputs.pr_number }},
              body: `✅ Pre-commit hooks completed successfully!\n\n🔧 Changes have been committed and pushed to the PR branch.`
            });

      - name: Comment success without changes
        if: steps.check_author.outputs.authorized == 'true' && steps.changes.outputs.has_changes == 'false' && steps.precommit.outcome == 'success'
        uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8.0.0
        with:
          github-token: ${{ secrets.GITHUB_TOKEN }}
          script: |
            await github.rest.issues.createComment({
              owner: context.repo.owner,
              repo: context.repo.repo,
              issue_number: ${{ steps.check_author.outputs.pr_number }},
              body: `✅ Pre-commit hooks passed!\n\n✨ No changes needed - your code is already formatted correctly.`
            });

      - name: Comment failure
        if: failure()
        uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8.0.0
        with:
          github-token: ${{ secrets.GITHUB_TOKEN }}
          script: |
            await github.rest.issues.createComment({
              owner: context.repo.owner,
              repo: context.repo.repo,
              issue_number: ${{ steps.check_author.outputs.pr_number }},
              body: `❌ Pre-commit workflow failed!\n\nPlease check the [workflow logs](https://github.com/${context.repo.owner}/${context.repo.repo}/actions/runs/${context.runId}) for details.`
            });

@@ -112,7 +112,7 @@ jobs:
           fi
           entrypoint=$(docker inspect --format '{{ .Config.Entrypoint }}' $IMAGE_ID)
           echo "Entrypoint: $entrypoint"
-          if [ "$entrypoint" != "[python -m llama_stack.core.server.server /app/run.yaml]" ]; then
+          if [ "$entrypoint" != "[llama stack run /app/run.yaml]" ]; then
             echo "Entrypoint is not correct"
             exit 1
           fi
@@ -150,7 +150,7 @@ jobs:
           fi
           entrypoint=$(docker inspect --format '{{ .Config.Entrypoint }}' $IMAGE_ID)
           echo "Entrypoint: $entrypoint"
-          if [ "$entrypoint" != "[python -m llama_stack.core.server.server /app/run.yaml]" ]; then
+          if [ "$entrypoint" != "[llama stack run /app/run.yaml]" ]; then
             echo "Entrypoint is not correct"
             exit 1
           fi

@@ -24,7 +24,7 @@ jobs:
         uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
       - name: Install uv
-        uses: astral-sh/setup-uv@b75a909f75acd358c2196fb9a5f1299a9a8868a4 # v6.7.0
+        uses: astral-sh/setup-uv@d0cc045d04ccac9d8b7881df0226f9e82c39688e # v6.8.0
         with:
           python-version: ${{ matrix.python-version }}
           activate-environment: true

@@ -7,7 +7,7 @@
 [![Unit Tests](https://github.com/meta-llama/llama-stack/actions/workflows/unit-tests.yml/badge.svg?branch=main)](https://github.com/meta-llama/llama-stack/actions/workflows/unit-tests.yml?query=branch%3Amain)
 [![Integration Tests](https://github.com/meta-llama/llama-stack/actions/workflows/integration-tests.yml/badge.svg?branch=main)](https://github.com/meta-llama/llama-stack/actions/workflows/integration-tests.yml?query=branch%3Amain)
-[**Quick Start**](https://llamastack.github.io/latest/getting_started/index.html) | [**Documentation**](https://llamastack.github.io/latest/index.html) | [**Colab Notebook**](./docs/getting_started.ipynb) | [**Discord**](https://discord.gg/llama-stack)
+[**Quick Start**](https://llamastack.github.io/docs/getting_started/quickstart) | [**Documentation**](https://llamastack.github.io/docs) | [**Colab Notebook**](./docs/getting_started.ipynb) | [**Discord**](https://discord.gg/llama-stack)
 ### ✨🎉 Llama 4 Support 🎉✨

@@ -187,21 +187,21 @@ Configure telemetry behavior using environment variables:
 - **`OTEL_SERVICE_NAME`**: Service name for telemetry (default: empty string)
 - **`TELEMETRY_SINKS`**: Comma-separated list of sinks (default: `console,sqlite`)
-## Visualization with Jaeger
+### Quick Setup: Complete Telemetry Stack
-The `otel_trace` sink works with any service compatible with the OpenTelemetry collector. Traces and metrics use separate endpoints but can share the same collector.
+Use the automated setup script to launch the complete telemetry stack (Jaeger, OpenTelemetry Collector, Prometheus, and Grafana):
-### Starting Jaeger
-Start a Jaeger instance with OTLP HTTP endpoint at 4318 and the Jaeger UI at 16686:
 ```bash
-docker run --pull always --rm --name jaeger \
-  -p 16686:16686 -p 4318:4318 \
-  jaegertracing/jaeger:2.1.0
+./scripts/telemetry/setup_telemetry.sh
 ```
-Once running, you can visualize traces by navigating to [http://localhost:16686/](http://localhost:16686/).
+This sets up:
+- **Jaeger UI**: http://localhost:16686 (traces visualization)
+- **Prometheus**: http://localhost:9090 (metrics)
+- **Grafana**: http://localhost:3000 (dashboards with auto-configured data sources)
+- **OTEL Collector**: http://localhost:4318 (OTLP endpoint)
+Once running, you can visualize traces by navigating to [Grafana](http://localhost:3000/) and log in with username `admin` and password `admin`.
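To verify the OTLP endpoint end to end, a test span can be sent straight from Python. A minimal sketch using the OpenTelemetry SDK (the packages `opentelemetry-sdk` and `opentelemetry-exporter-otlp-proto-http` are assumptions here, not part of the stack setup):

```python
# Smoke test: send one span to the collector started by setup_telemetry.sh
# (OTLP/HTTP on localhost:4318), then look for it in the Jaeger UI.
from opentelemetry import trace
from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import BatchSpanProcessor

provider = TracerProvider()
provider.add_span_processor(
    BatchSpanProcessor(OTLPSpanExporter(endpoint="http://localhost:4318/v1/traces"))
)
trace.set_tracer_provider(provider)

with trace.get_tracer("telemetry-smoke-test").start_as_current_span("hello"):
    pass  # span body intentionally empty

provider.shutdown()  # flush the batch processor before the process exits
```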
 ## Querying Metrics

@@ -152,7 +152,6 @@ __all__ = ["WeatherAPI", "available_providers"]
 from typing import Protocol
 from llama_stack.providers.datatypes import (
-    AdapterSpec,
     Api,
     ProviderSpec,
     RemoteProviderSpec,
@@ -166,12 +165,10 @@ def available_providers() -> list[ProviderSpec]:
             api=Api.weather,
             provider_type="remote::kaze",
             config_class="llama_stack_provider_kaze.KazeProviderConfig",
-            adapter=AdapterSpec(
-                adapter_type="kaze",
-                module="llama_stack_provider_kaze",
-                pip_packages=["llama_stack_provider_kaze"],
-                config_class="llama_stack_provider_kaze.KazeProviderConfig",
-            ),
+            adapter_type="kaze",
+            module="llama_stack_provider_kaze",
+            pip_packages=["llama_stack_provider_kaze"],
         ),
     ]
@@ -325,11 +322,10 @@ class WeatherKazeAdapter(WeatherProvider):
 ```yaml
 # ~/.llama/providers.d/remote/weather/kaze.yaml
-adapter:
-  adapter_type: kaze
-  pip_packages: ["llama_stack_provider_kaze"]
-  config_class: llama_stack_provider_kaze.config.KazeProviderConfig
-  module: llama_stack_provider_kaze
+adapter_type: kaze
+pip_packages: ["llama_stack_provider_kaze"]
+config_class: llama_stack_provider_kaze.config.KazeProviderConfig
+module: llama_stack_provider_kaze
 optional_api_dependencies: []
 ```
@@ -361,7 +357,7 @@ server:
 8. Run the server:
 ```bash
-python -m llama_stack.core.server.server --yaml-config ~/.llama/run-byoa.yaml
+llama stack run ~/.llama/run-byoa.yaml
 ```
 9. Test the API:

@@ -170,7 +170,7 @@ spec:
       - name: llama-stack
         image: localhost/llama-stack-run-k8s:latest
         imagePullPolicy: IfNotPresent
-        command: ["python", "-m", "llama_stack.core.server.server", "--config", "/app/config.yaml"]
+        command: ["llama", "stack", "run", "/app/config.yaml"]
         ports:
           - containerPort: 5000
         volumeMounts:

@@ -509,16 +509,16 @@ server:
     provider_config:
       type: "github_token"
       github_api_base_url: "https://api.github.com"
   access_policy:
     - permit:
         principal: user-1
         actions: [create, read, delete]
       description: user-1 has full access to all resources
     - permit:
         principal: user-2
         actions: [read]
         resource: model::model-1
       description: user-2 has read access to model-1 only
 ```
 Similarly, the following restricts access to particular kubernetes

@@ -52,7 +52,7 @@ spec:
           value: "${SAFETY_MODEL}"
         - name: TAVILY_SEARCH_API_KEY
           value: "${TAVILY_SEARCH_API_KEY}"
-        command: ["python", "-m", "llama_stack.core.server.server", "/etc/config/stack_run_config.yaml", "--port", "8321"]
+        command: ["llama", "stack", "run", "/etc/config/stack_run_config.yaml", "--port", "8321"]
         ports:
           - containerPort: 8321
         volumeMounts:

@@ -11,38 +11,6 @@ an example entry in your build.yaml should look like:
 module: ramalama_stack
 ```
-Additionally you can configure the `external_providers_dir` in your Llama Stack configuration. This method is in the process of being deprecated in favor of the `module` method. If using this method, the external provider directory should contain your external provider specifications:
-```yaml
-external_providers_dir: ~/.llama/providers.d/
-```
-## Directory Structure
-The external providers directory should follow this structure:
-```
-providers.d/
-  remote/
-    inference/
-      custom_ollama.yaml
-      vllm.yaml
-    vector_io/
-      qdrant.yaml
-    safety/
-      llama-guard.yaml
-  inline/
-    inference/
-      custom_ollama.yaml
-      vllm.yaml
-    vector_io/
-      qdrant.yaml
-    safety/
-      llama-guard.yaml
-```
-Each YAML file in these directories defines a provider specification for that particular API.
 ## Provider Types
 Llama Stack supports two types of external providers:
@@ -50,30 +18,37 @@ Llama Stack supports two types of external providers:
 1. **Remote Providers**: Providers that communicate with external services (e.g., cloud APIs)
 2. **Inline Providers**: Providers that run locally within the Llama Stack process
+### Provider Specification (Common between inline and remote providers)
+- `provider_type`: The type of the provider to be installed (remote or inline), e.g., `remote::ollama`
+- `api`: The API for this provider, e.g., `inference`
+- `config_class`: The full path to the configuration class
+- `module`: The Python module containing the provider implementation
+- `optional_api_dependencies`: List of optional Llama Stack APIs that this provider can use
+- `api_dependencies`: List of Llama Stack APIs that this provider depends on
+- `provider_data_validator`: Optional validator for provider data
+- `pip_packages`: List of Python packages required by the provider
 ### Remote Provider Specification
 Remote providers are used when you need to communicate with external services. Here's an example for a custom Ollama provider:
 ```yaml
-adapter:
-  adapter_type: custom_ollama
+adapter_type: custom_ollama
+provider_type: "remote::ollama"
 pip_packages:
 - ollama
 - aiohttp
 config_class: llama_stack_ollama_provider.config.OllamaImplConfig
 module: llama_stack_ollama_provider
 api_dependencies: []
 optional_api_dependencies: []
 ```
-#### Adapter Configuration
+#### Remote Provider Configuration
-The `adapter` section defines how to load and configure the provider:
-- `adapter_type`: A unique identifier for this adapter
-- `pip_packages`: List of Python packages required by the provider
-- `config_class`: The full path to the configuration class
-- `module`: The Python module containing the provider implementation
+- `adapter_type`: A unique identifier for this adapter, e.g., `ollama`
 ### Inline Provider Specification
@@ -81,6 +56,7 @@ Inline providers run locally within the Llama Stack process. Here's an example f
 ```yaml
 module: llama_stack_vector_provider
+provider_type: inline::llama_stack_vector_provider
 config_class: llama_stack_vector_provider.config.VectorStoreConfig
 pip_packages:
 - faiss-cpu
@@ -95,12 +71,6 @@ container_image: custom-vector-store:latest # optional
 #### Inline Provider Fields
-- `module`: The Python module containing the provider implementation
-- `config_class`: The full path to the configuration class
-- `pip_packages`: List of Python packages required by the provider
-- `api_dependencies`: List of Llama Stack APIs that this provider depends on
-- `optional_api_dependencies`: List of optional Llama Stack APIs that this provider can use
-- `provider_data_validator`: Optional validator for provider data
 - `container_image`: Optional container image to use instead of pip packages
 ## Required Fields
@@ -113,20 +83,17 @@ All providers must contain a `get_provider_spec` function in their `provider` mo
 from llama_stack.providers.datatypes import (
     ProviderSpec,
     Api,
-    AdapterSpec,
-    remote_provider_spec,
+    RemoteProviderSpec,
 )

 def get_provider_spec() -> ProviderSpec:
-    return remote_provider_spec(
+    return RemoteProviderSpec(
         api=Api.inference,
-        adapter=AdapterSpec(
-            adapter_type="ramalama",
-            pip_packages=["ramalama>=0.8.5", "pymilvus"],
-            config_class="ramalama_stack.config.RamalamaImplConfig",
-            module="ramalama_stack",
-        ),
+        adapter_type="ramalama",
+        pip_packages=["ramalama>=0.8.5", "pymilvus"],
+        config_class="ramalama_stack.config.RamalamaImplConfig",
+        module="ramalama_stack",
     )
 ```
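The example above covers the remote case. For an inline provider, the same `get_provider_spec` hook would return an inline spec instead; a sketch, assuming `InlineProviderSpec` is exported from `llama_stack.providers.datatypes` with the fields documented earlier:

```python
# Sketch of the inline analogue (class name and field set assumed to mirror
# the documented inline provider fields; not taken from this diff).
from llama_stack.providers.datatypes import (
    Api,
    InlineProviderSpec,
    ProviderSpec,
)


def get_provider_spec() -> ProviderSpec:
    return InlineProviderSpec(
        api=Api.vector_io,
        provider_type="inline::llama_stack_vector_provider",
        pip_packages=["faiss-cpu"],
        module="llama_stack_vector_provider",
        config_class="llama_stack_vector_provider.config.VectorStoreConfig",
    )
```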
@@ -197,18 +164,16 @@ information. Execute the test for the Provider type you are developing.
 If your external provider isn't being loaded:
 1. Check that `module` points to a published pip package with a top-level `provider` module including `get_provider_spec`.
-1. Check that the `external_providers_dir` path is correct and accessible.
 2. Verify that the YAML files are properly formatted.
 3. Ensure all required Python packages are installed.
 4. Check the Llama Stack server logs for any error messages - turn on debug logging to get more
    information using `LLAMA_STACK_LOGGING=all=debug`.
-5. Verify that the provider package is installed in your Python environment if using `external_providers_dir`.
 ## Examples
-### Example using `external_providers_dir`: Custom Ollama Provider
+### How to create an external provider module
-Here's a complete example of creating and using a custom Ollama provider:
+If you are creating a new external provider called `llama-stack-provider-ollama`, here is how you would set up the package properly:
 1. First, create the provider package:
@@ -230,33 +195,28 @@ requires-python = ">=3.12"
 dependencies = ["llama-stack", "pydantic", "ollama", "aiohttp"]
 ```
-3. Create the provider specification:
-```yaml
-# ~/.llama/providers.d/remote/inference/custom_ollama.yaml
-adapter:
-  adapter_type: custom_ollama
-  pip_packages: ["ollama", "aiohttp"]
-  config_class: llama_stack_provider_ollama.config.OllamaImplConfig
-  module: llama_stack_provider_ollama
-api_dependencies: []
-optional_api_dependencies: []
-```
-4. Install the provider:
+3. Install the provider:
 ```bash
 uv pip install -e .
 ```
-5. Configure Llama Stack to use external providers:
-```yaml
-external_providers_dir: ~/.llama/providers.d/
-```
+4. Edit `provider.py`
+`provider.py` must be updated to contain `get_provider_spec`. This is used by Llama Stack to install the provider.
+```python
+def get_provider_spec() -> ProviderSpec:
+    return RemoteProviderSpec(
+        api=Api.inference,
+        adapter_type="llama-stack-provider-ollama",
+        pip_packages=["ollama", "aiohttp"],
+        config_class="llama_stack_provider_ollama.config.OllamaImplConfig",
+        module="llama_stack_provider_ollama",
+    )
+```
-The provider will now be available in Llama Stack with the type `remote::custom_ollama`.
+5. Implement the provider as outlined above with `get_provider_impl` or `get_adapter_impl`, etc.
 ### Example using `module`: ramalama-stack
@@ -275,7 +235,6 @@ distribution_spec:
   module: ramalama_stack==0.3.0a0
 image_type: venv
 image_name: null
-external_providers_dir: null
 additional_pip_packages:
 - aiosqlite
 - sqlalchemy[asyncio]

@@ -1,4 +1,7 @@
 ---
+description: "Files
+
+  This API is used to upload documents that can be used with other Llama Stack APIs."
 sidebar_label: Files
 title: Files
 ---
@@ -7,4 +10,8 @@ title: Files
 ## Overview
+Files
+
+This API is used to upload documents that can be used with other Llama Stack APIs.
+
 This section contains documentation for all available providers for the **files** API.

@@ -1,5 +1,7 @@
 ---
-description: "Llama Stack Inference API for generating completions, chat completions, and embeddings.
+description: "Inference
+
+  Llama Stack Inference API for generating completions, chat completions, and embeddings.
 This API provides the raw interface to the underlying models. Two kinds of models are supported:
 - LLM models: these models generate \"raw\" and \"chat\" (conversational) completions.
@@ -12,7 +14,9 @@ title: Inference
 ## Overview
-Llama Stack Inference API for generating completions, chat completions, and embeddings.
+Inference
+
+Llama Stack Inference API for generating completions, chat completions, and embeddings.
 This API provides the raw interface to the underlying models. Two kinds of models are supported:
 - LLM models: these models generate "raw" and "chat" (conversational) completions.

@@ -14,6 +14,7 @@ Anthropic inference provider for accessing Claude models and Anthropic's AI serv
 | Field | Type | Required | Default | Description |
 |-------|------|----------|---------|-------------|
+| `allowed_models` | `list[str] \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
 | `api_key` | `str \| None` | No | | API key for Anthropic models |
 ## Sample Configuration

@@ -21,6 +21,7 @@ https://learn.microsoft.com/en-us/azure/ai-foundry/openai/overview
 | Field | Type | Required | Default | Description |
 |-------|------|----------|---------|-------------|
+| `allowed_models` | `list[str] \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
 | `api_key` | `<class 'pydantic.types.SecretStr'>` | No | | Azure API key for Azure |
 | `api_base` | `<class 'pydantic.networks.HttpUrl'>` | No | | Azure API base for Azure (e.g., https://your-resource-name.openai.azure.com) |
 | `api_version` | `str \| None` | No | | Azure API version for Azure (e.g., 2024-12-01-preview) |

@@ -14,6 +14,7 @@ AWS Bedrock inference provider for accessing various AI models through AWS's man
 | Field | Type | Required | Default | Description |
 |-------|------|----------|---------|-------------|
+| `allowed_models` | `list[str] \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
 | `aws_access_key_id` | `str \| None` | No | | The AWS access key to use. Default use environment variable: AWS_ACCESS_KEY_ID |
 | `aws_secret_access_key` | `str \| None` | No | | The AWS secret access key to use. Default use environment variable: AWS_SECRET_ACCESS_KEY |
 | `aws_session_token` | `str \| None` | No | | The AWS session token to use. Default use environment variable: AWS_SESSION_TOKEN |

@@ -14,6 +14,7 @@ Cerebras inference provider for running models on Cerebras Cloud platform.
 | Field | Type | Required | Default | Description |
 |-------|------|----------|---------|-------------|
+| `allowed_models` | `list[str] \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
 | `base_url` | `<class 'str'>` | No | https://api.cerebras.ai | Base URL for the Cerebras API |
 | `api_key` | `<class 'pydantic.types.SecretStr'>` | No | | Cerebras API Key |

@@ -14,7 +14,8 @@ Databricks inference provider for running models on Databricks' unified analytic
 | Field | Type | Required | Default | Description |
 |-------|------|----------|---------|-------------|
-| `url` | `<class 'str'>` | No | | The URL for the Databricks model serving endpoint |
+| `allowed_models` | `list[str] \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
+| `url` | `str \| None` | No | | The URL for the Databricks model serving endpoint |
 | `api_token` | `<class 'pydantic.types.SecretStr'>` | No | | The Databricks API token |
 ## Sample Configuration

@@ -14,6 +14,7 @@ Google Gemini inference provider for accessing Gemini models and Google's AI ser
 | Field | Type | Required | Default | Description |
 |-------|------|----------|---------|-------------|
+| `allowed_models` | `list[str] \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
 | `api_key` | `str \| None` | No | | API key for Gemini models |
 ## Sample Configuration

@@ -14,6 +14,7 @@ Groq inference provider for ultra-fast inference using Groq's LPU technology.
 | Field | Type | Required | Default | Description |
 |-------|------|----------|---------|-------------|
+| `allowed_models` | `list[str] \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
 | `api_key` | `str \| None` | No | | The Groq API key |
 | `url` | `<class 'str'>` | No | https://api.groq.com | The URL for the Groq AI server |

@@ -14,6 +14,7 @@ Llama OpenAI-compatible provider for using Llama models with OpenAI API format.
 | Field | Type | Required | Default | Description |
 |-------|------|----------|---------|-------------|
+| `allowed_models` | `list[str] \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
 | `api_key` | `str \| None` | No | | The Llama API key |
 | `openai_compat_api_base` | `<class 'str'>` | No | https://api.llama.com/compat/v1/ | The URL for the Llama API server |

@@ -14,6 +14,7 @@ NVIDIA inference provider for accessing NVIDIA NIM models and AI services.
 | Field | Type | Required | Default | Description |
 |-------|------|----------|---------|-------------|
+| `allowed_models` | `list[str] \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
 | `url` | `<class 'str'>` | No | https://integrate.api.nvidia.com | A base url for accessing the NVIDIA NIM |
 | `api_key` | `pydantic.types.SecretStr \| None` | No | | The NVIDIA API key, only needed if using the hosted service |
 | `timeout` | `<class 'int'>` | No | 60 | Timeout for the HTTP requests |

@@ -14,6 +14,7 @@ Ollama inference provider for running local models through the Ollama runtime.
 | Field | Type | Required | Default | Description |
 |-------|------|----------|---------|-------------|
+| `allowed_models` | `list[str] \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
 | `url` | `<class 'str'>` | No | http://localhost:11434 | |
 | `refresh_models` | `<class 'bool'>` | No | False | Whether to refresh models periodically |
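The new `allowed_models` field restricts which models get registered with the model registry; a hedged sketch of setting it programmatically (the config class import path is an assumption, not part of this diff):

```python
# Sketch: restrict auto-registration to one local model.
# allowed_models=None (the default) keeps the documented behavior
# of allowing every model the runtime reports.
from llama_stack.providers.remote.inference.ollama.config import OllamaImplConfig

config = OllamaImplConfig(
    url="http://localhost:11434",
    refresh_models=False,
    allowed_models=["llama3.2:3b"],
)
```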

@@ -14,6 +14,7 @@ OpenAI inference provider for accessing GPT models and other OpenAI services.
 | Field | Type | Required | Default | Description |
 |-------|------|----------|---------|-------------|
+| `allowed_models` | `list[str] \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
 | `api_key` | `str \| None` | No | | API key for OpenAI models |
 | `base_url` | `<class 'str'>` | No | https://api.openai.com/v1 | Base URL for OpenAI API |

@@ -14,6 +14,7 @@ Passthrough inference provider for connecting to any external inference service
 | Field | Type | Required | Default | Description |
 |-------|------|----------|---------|-------------|
+| `allowed_models` | `list[str] \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
 | `url` | `<class 'str'>` | No | | The URL for the passthrough endpoint |
 | `api_key` | `pydantic.types.SecretStr \| None` | No | | API Key for the passthrough endpoint |

@@ -14,6 +14,7 @@ RunPod inference provider for running models on RunPod's cloud GPU platform.
 | Field | Type | Required | Default | Description |
 |-------|------|----------|---------|-------------|
+| `allowed_models` | `list[str] \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
 | `url` | `str \| None` | No | | The URL for the Runpod model serving endpoint |
 | `api_token` | `str \| None` | No | | The API token |

@@ -14,6 +14,7 @@ SambaNova inference provider for running models on SambaNova's dataflow architec
 | Field | Type | Required | Default | Description |
 |-------|------|----------|---------|-------------|
+| `allowed_models` | `list[str] \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
 | `url` | `<class 'str'>` | No | https://api.sambanova.ai/v1 | The URL for the SambaNova AI server |
 | `api_key` | `pydantic.types.SecretStr \| None` | No | | The SambaNova cloud API Key |

@@ -14,6 +14,7 @@ Text Generation Inference (TGI) provider for HuggingFace model serving.
 | Field | Type | Required | Default | Description |
 |-------|------|----------|---------|-------------|
+| `allowed_models` | `list[str] \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
 | `url` | `<class 'str'>` | No | | The URL for the TGI serving endpoint |
 ## Sample Configuration

@@ -53,6 +53,7 @@ Available Models:
 | Field | Type | Required | Default | Description |
 |-------|------|----------|---------|-------------|
+| `allowed_models` | `list[str] \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
 | `project` | `<class 'str'>` | No | | Google Cloud project ID for Vertex AI |
 | `location` | `<class 'str'>` | No | us-central1 | Google Cloud location for Vertex AI |

@@ -14,6 +14,7 @@ Remote vLLM inference provider for connecting to vLLM servers.
 | Field | Type | Required | Default | Description |
 |-------|------|----------|---------|-------------|
+| `allowed_models` | `list[str] \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
 | `url` | `str \| None` | No | | The URL for the vLLM model serving endpoint |
 | `max_tokens` | `<class 'int'>` | No | 4096 | Maximum number of tokens to generate. |
 | `api_token` | `str \| None` | No | fake | The API token |

@@ -14,6 +14,7 @@ IBM WatsonX inference provider for accessing AI models on IBM's WatsonX platform
 | Field | Type | Required | Default | Description |
 |-------|------|----------|---------|-------------|
+| `allowed_models` | `list[str] \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
 | `url` | `<class 'str'>` | No | https://us-south.ml.cloud.ibm.com | A base url for accessing watsonx.ai |
 | `api_key` | `pydantic.types.SecretStr \| None` | No | | The watsonx API key |
 | `project_id` | `str \| None` | No | | The Project ID key |

@@ -1,4 +1,7 @@
 ---
+description: "Safety
+
+  OpenAI-compatible Moderations API."
 sidebar_label: Safety
 title: Safety
 ---
@@ -7,4 +10,8 @@ title: Safety
 ## Overview
+Safety
+
+OpenAI-compatible Moderations API.
+
 This section contains documentation for all available providers for the **safety** API.

@@ -14,6 +14,7 @@ AWS Bedrock safety provider for content moderation using AWS's safety services.
 | Field | Type | Required | Default | Description |
 |-------|------|----------|---------|-------------|
+| `allowed_models` | `list[str] \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
 | `aws_access_key_id` | `str \| None` | No | | The AWS access key to use. Default use environment variable: AWS_ACCESS_KEY_ID |
 | `aws_secret_access_key` | `str \| None` | No | | The AWS secret access key to use. Default use environment variable: AWS_SECRET_ACCESS_KEY |
 | `aws_session_token` | `str \| None` | No | | The AWS session token to use. Default use environment variable: AWS_SESSION_TOKEN |

@@ -50,6 +50,7 @@ from .specification import (
     Document,
     Example,
     ExampleRef,
+    ExtraBodyParameter,
     MediaType,
     Operation,
     Parameter,
@@ -677,6 +678,27 @@ class Generator:
         # parameters passed anywhere
         parameters = path_parameters + query_parameters

+        # Build extra body parameters documentation
+        extra_body_parameters = []
+        for param_name, param_type, description in op.extra_body_params:
+            if is_type_optional(param_type):
+                inner_type: type = unwrap_optional_type(param_type)
+                required = False
+            else:
+                inner_type = param_type
+                required = True
+            # Use description from ExtraBodyField if available, otherwise from docstring
+            param_description = description or doc_params.get(param_name)
+            extra_body_param = ExtraBodyParameter(
+                name=param_name,
+                schema=self.schema_builder.classdef_to_ref(inner_type),
+                description=param_description,
+                required=required,
+            )
+            extra_body_parameters.append(extra_body_param)
+
         webmethod = getattr(op.func_ref, "__webmethod__", None)
         raw_bytes_request_body = False
         if webmethod:
@@ -898,6 +920,7 @@ class Generator:
             deprecated=getattr(op.webmethod, "deprecated", False)
             or "DEPRECATED" in op.func_name,
             security=[] if op.public else None,
+            extraBodyParameters=extra_body_parameters if extra_body_parameters else None,
         )

     def _get_api_stability_priority(self, api_level: str) -> int:

View file

@ -19,10 +19,12 @@ from llama_stack.strong_typing.inspection import get_signature
from typing import get_origin, get_args from typing import get_origin, get_args
from fastapi import UploadFile from fastapi import UploadFile
from fastapi.params import File, Form from fastapi.params import File, Form
from typing import Annotated from typing import Annotated
from llama_stack.schema_utils import ExtraBodyField
def split_prefix( def split_prefix(
s: str, sep: str, prefix: Union[str, Iterable[str]] s: str, sep: str, prefix: Union[str, Iterable[str]]
@ -89,6 +91,7 @@ class EndpointOperation:
:param query_params: Parameters of the operation signature that are passed in the query string as `key=value` pairs. :param query_params: Parameters of the operation signature that are passed in the query string as `key=value` pairs.
:param request_params: The parameter that corresponds to the data transmitted in the request body. :param request_params: The parameter that corresponds to the data transmitted in the request body.
:param multipart_params: Parameters that indicate multipart/form-data request body. :param multipart_params: Parameters that indicate multipart/form-data request body.
:param extra_body_params: Parameters that arrive via extra_body and are documented but not in SDK.
:param event_type: The Python type of the data that is transmitted out-of-band (e.g. via websockets) while the operation is in progress. :param event_type: The Python type of the data that is transmitted out-of-band (e.g. via websockets) while the operation is in progress.
:param response_type: The Python type of the data that is transmitted in the response body. :param response_type: The Python type of the data that is transmitted in the response body.
:param http_method: The HTTP method used to invoke the endpoint such as POST, GET or PUT. :param http_method: The HTTP method used to invoke the endpoint such as POST, GET or PUT.
@ -106,6 +109,7 @@ class EndpointOperation:
query_params: List[OperationParameter] query_params: List[OperationParameter]
request_params: Optional[OperationParameter] request_params: Optional[OperationParameter]
multipart_params: List[OperationParameter] multipart_params: List[OperationParameter]
extra_body_params: List[tuple[str, type, str | None]]
event_type: Optional[type] event_type: Optional[type]
response_type: type response_type: type
http_method: HTTPMethod http_method: HTTPMethod
@ -265,6 +269,7 @@ def get_endpoint_operations(
query_params = [] query_params = []
request_params = [] request_params = []
multipart_params = [] multipart_params = []
extra_body_params = []
for param_name, parameter in signature.parameters.items(): for param_name, parameter in signature.parameters.items():
param_type = _get_annotation_type(parameter.annotation, func_ref) param_type = _get_annotation_type(parameter.annotation, func_ref)
@ -279,6 +284,13 @@ def get_endpoint_operations(
f"parameter '{param_name}' in function '{func_name}' has no type annotation" f"parameter '{param_name}' in function '{func_name}' has no type annotation"
) )
# Check if this is an extra_body parameter
is_extra_body, extra_body_desc = _is_extra_body_param(param_type)
if is_extra_body:
# Store in a separate list for documentation
extra_body_params.append((param_name, param_type, extra_body_desc))
continue # Skip adding to request_params
is_multipart = _is_multipart_param(param_type) is_multipart = _is_multipart_param(param_type)
if prefix in ["get", "delete"]: if prefix in ["get", "delete"]:
@ -351,6 +363,7 @@ def get_endpoint_operations(
query_params=query_params, query_params=query_params,
request_params=request_params, request_params=request_params,
multipart_params=multipart_params, multipart_params=multipart_params,
extra_body_params=extra_body_params,
event_type=event_type, event_type=event_type,
response_type=response_type, response_type=response_type,
http_method=http_method, http_method=http_method,
@ -403,7 +416,7 @@ def get_endpoint_events(endpoint: type) -> Dict[str, type]:
def _is_multipart_param(param_type: type) -> bool: def _is_multipart_param(param_type: type) -> bool:
""" """
Check if a parameter type indicates multipart form data. Check if a parameter type indicates multipart form data.
Returns True if the type is: Returns True if the type is:
- UploadFile - UploadFile
- Annotated[UploadFile, File()] - Annotated[UploadFile, File()]
@ -413,19 +426,38 @@ def _is_multipart_param(param_type: type) -> bool:
""" """
if param_type is UploadFile: if param_type is UploadFile:
return True return True
# Check for Annotated types # Check for Annotated types
origin = get_origin(param_type) origin = get_origin(param_type)
if origin is None: if origin is None:
return False return False
if origin is Annotated: if origin is Annotated:
args = get_args(param_type) args = get_args(param_type)
if len(args) < 2: if len(args) < 2:
return False return False
# Check the annotations for File() or Form() # Check the annotations for File() or Form()
for annotation in args[1:]: for annotation in args[1:]:
if isinstance(annotation, (File, Form)): if isinstance(annotation, (File, Form)):
return True return True
return False return False
def _is_extra_body_param(param_type: type) -> tuple[bool, str | None]:
"""
Check if parameter is marked as coming from extra_body.
Returns:
(is_extra_body, description): Tuple of boolean and optional description
"""
origin = get_origin(param_type)
if origin is Annotated:
args = get_args(param_type)
for annotation in args[1:]:
if isinstance(annotation, ExtraBodyField):
return True, annotation.description
# Also check by class name, in case ExtraBodyField was imported via a different module path
if type(annotation).__name__ == "ExtraBodyField":
return True, getattr(annotation, "description", None)
return False, None
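An illustrative sketch, not part of this diff: a parameter annotated with ExtraBodyField is picked up by _is_extra_body_param and routed into extra_body_params instead of the request schema. The ShieldsParam alias below is hypothetical.

from typing import Annotated

from llama_stack.schema_utils import ExtraBodyField

# Hypothetical annotation for a value that clients would send via extra_body.
ShieldsParam = Annotated[
    list[str] | None,
    ExtraBodyField("List of shields to apply during response generation."),
]

is_extra, description = _is_extra_body_param(ShieldsParam)
# is_extra is True and description carries the text above, so the generator
# records this parameter under extra_body_params (and ultimately under
# x-llama-stack-extra-body-params) rather than request_params.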

View file

@ -106,6 +106,15 @@ class Parameter:
example: Optional[Any] = None example: Optional[Any] = None
@dataclass
class ExtraBodyParameter:
"""Represents a parameter that arrives via extra_body in the request."""
name: str
schema: SchemaOrRef
description: Optional[str] = None
required: Optional[bool] = None
@dataclass @dataclass
class Operation: class Operation:
responses: Dict[str, Union[Response, ResponseRef]] responses: Dict[str, Union[Response, ResponseRef]]
@ -118,6 +127,7 @@ class Operation:
callbacks: Optional[Dict[str, "Callback"]] = None callbacks: Optional[Dict[str, "Callback"]] = None
security: Optional[List["SecurityRequirement"]] = None security: Optional[List["SecurityRequirement"]] = None
deprecated: Optional[bool] = None deprecated: Optional[bool] = None
extraBodyParameters: Optional[List[ExtraBodyParameter]] = None
@dataclass @dataclass

View file

@ -52,6 +52,17 @@ class Specification:
if display_name: if display_name:
tag["x-displayName"] = display_name tag["x-displayName"] = display_name
# Handle operations to rename extraBodyParameters -> x-llama-stack-extra-body-params
paths = json_doc.get("paths", {})
for path_item in paths.values():
if isinstance(path_item, dict):
for method in ["get", "post", "put", "delete", "patch"]:
operation = path_item.get(method)
if operation and isinstance(operation, dict):
extra_body_params = operation.pop("extraBodyParameters", None)
if extra_body_params:
operation["x-llama-stack-extra-body-params"] = extra_body_params
return json_doc return json_doc
def get_json_string(self, pretty_print: bool = False) -> str: def get_json_string(self, pretty_print: bool = False) -> str:

View file

@ -1443,8 +1443,8 @@
"tags": [ "tags": [
"Inference" "Inference"
], ],
"summary": "List all chat completions.", "summary": "List chat completions.",
"description": "List all chat completions.", "description": "List chat completions.",
"parameters": [ "parameters": [
{ {
"name": "after", "name": "after",
@ -1520,8 +1520,8 @@
"tags": [ "tags": [
"Inference" "Inference"
], ],
"summary": "Generate an OpenAI-compatible chat completion for the given messages using the specified model.", "summary": "Create chat completions.",
"description": "Generate an OpenAI-compatible chat completion for the given messages using the specified model.", "description": "Create chat completions.\nGenerate an OpenAI-compatible chat completion for the given messages using the specified model.",
"parameters": [], "parameters": [],
"requestBody": { "requestBody": {
"content": { "content": {
@ -1565,8 +1565,8 @@
"tags": [ "tags": [
"Inference" "Inference"
], ],
"summary": "Describe a chat completion by its ID.", "summary": "Get chat completion.",
"description": "Describe a chat completion by its ID.", "description": "Get chat completion.\nDescribe a chat completion by its ID.",
"parameters": [ "parameters": [
{ {
"name": "completion_id", "name": "completion_id",
@ -1610,8 +1610,8 @@
"tags": [ "tags": [
"Inference" "Inference"
], ],
"summary": "Generate an OpenAI-compatible completion for the given prompt using the specified model.", "summary": "Create completion.",
"description": "Generate an OpenAI-compatible completion for the given prompt using the specified model.", "description": "Create completion.\nGenerate an OpenAI-compatible completion for the given prompt using the specified model.",
"parameters": [], "parameters": [],
"requestBody": { "requestBody": {
"content": { "content": {
@ -1655,8 +1655,8 @@
"tags": [ "tags": [
"Inference" "Inference"
], ],
"summary": "Generate OpenAI-compatible embeddings for the given input using the specified model.", "summary": "Create embeddings.",
"description": "Generate OpenAI-compatible embeddings for the given input using the specified model.", "description": "Create embeddings.\nGenerate OpenAI-compatible embeddings for the given input using the specified model.",
"parameters": [], "parameters": [],
"requestBody": { "requestBody": {
"content": { "content": {
@ -1700,8 +1700,8 @@
"tags": [ "tags": [
"Files" "Files"
], ],
"summary": "Returns a list of files that belong to the user's organization.", "summary": "List files.",
"description": "Returns a list of files that belong to the user's organization.", "description": "List files.\nReturns a list of files that belong to the user's organization.",
"parameters": [ "parameters": [
{ {
"name": "after", "name": "after",
@ -1770,8 +1770,8 @@
"tags": [ "tags": [
"Files" "Files"
], ],
"summary": "Upload a file that can be used across various endpoints.", "summary": "Upload file.",
"description": "Upload a file that can be used across various endpoints.\nThe file upload should be a multipart form request with:\n- file: The File object (not file name) to be uploaded.\n- purpose: The intended purpose of the uploaded file.\n- expires_after: Optional form values describing expiration for the file.", "description": "Upload file.\nUpload a file that can be used across various endpoints.\n\nThe file upload should be a multipart form request with:\n- file: The File object (not file name) to be uploaded.\n- purpose: The intended purpose of the uploaded file.\n- expires_after: Optional form values describing expiration for the file.",
"parameters": [], "parameters": [],
"requestBody": { "requestBody": {
"content": { "content": {
@ -1831,8 +1831,8 @@
"tags": [ "tags": [
"Files" "Files"
], ],
"summary": "Returns information about a specific file.", "summary": "Retrieve file.",
"description": "Returns information about a specific file.", "description": "Retrieve file.\nReturns information about a specific file.",
"parameters": [ "parameters": [
{ {
"name": "file_id", "name": "file_id",
@ -1874,8 +1874,8 @@
"tags": [ "tags": [
"Files" "Files"
], ],
"summary": "Delete a file.", "summary": "Delete file.",
"description": "Delete a file.", "description": "Delete file.",
"parameters": [ "parameters": [
{ {
"name": "file_id", "name": "file_id",
@ -1919,8 +1919,8 @@
"tags": [ "tags": [
"Files" "Files"
], ],
"summary": "Returns the contents of the specified file.", "summary": "Retrieve file content.",
"description": "Returns the contents of the specified file.", "description": "Retrieve file content.\nReturns the contents of the specified file.",
"parameters": [ "parameters": [
{ {
"name": "file_id", "name": "file_id",
@ -1999,8 +1999,8 @@
"tags": [ "tags": [
"Safety" "Safety"
], ],
"summary": "Classifies if text and/or image inputs are potentially harmful.", "summary": "Create moderation.",
"description": "Classifies if text and/or image inputs are potentially harmful.", "description": "Create moderation.\nClassifies if text and/or image inputs are potentially harmful.",
"parameters": [], "parameters": [],
"requestBody": { "requestBody": {
"content": { "content": {
@ -2044,8 +2044,8 @@
"tags": [ "tags": [
"Agents" "Agents"
], ],
"summary": "List all OpenAI responses.", "summary": "List all responses.",
"description": "List all OpenAI responses.", "description": "List all responses.",
"parameters": [ "parameters": [
{ {
"name": "after", "name": "after",
@ -2119,8 +2119,8 @@
"tags": [ "tags": [
"Agents" "Agents"
], ],
"summary": "Create a new OpenAI response.", "summary": "Create a model response.",
"description": "Create a new OpenAI response.", "description": "Create a model response.",
"parameters": [], "parameters": [],
"requestBody": { "requestBody": {
"content": { "content": {
@ -2132,7 +2132,27 @@
}, },
"required": true "required": true
}, },
"deprecated": true "deprecated": true,
"x-llama-stack-extra-body-params": [
{
"name": "shields",
"schema": {
"type": "array",
"items": {
"oneOf": [
{
"type": "string"
},
{
"$ref": "#/components/schemas/ResponseShieldSpec"
}
]
}
},
"description": "List of shields to apply during response generation. Shields provide safety and content moderation.",
"required": false
}
]
} }
}, },
"/v1/openai/v1/responses/{response_id}": { "/v1/openai/v1/responses/{response_id}": {
@ -2164,8 +2184,8 @@
"tags": [ "tags": [
"Agents" "Agents"
], ],
"summary": "Retrieve an OpenAI response by its ID.", "summary": "Get a model response.",
"description": "Retrieve an OpenAI response by its ID.", "description": "Get a model response.",
"parameters": [ "parameters": [
{ {
"name": "response_id", "name": "response_id",
@ -2207,8 +2227,8 @@
"tags": [ "tags": [
"Agents" "Agents"
], ],
"summary": "Delete an OpenAI response by its ID.", "summary": "Delete a response.",
"description": "Delete an OpenAI response by its ID.", "description": "Delete a response.",
"parameters": [ "parameters": [
{ {
"name": "response_id", "name": "response_id",
@ -2252,8 +2272,8 @@
"tags": [ "tags": [
"Agents" "Agents"
], ],
"summary": "List input items for a given OpenAI response.", "summary": "List input items.",
"description": "List input items for a given OpenAI response.", "description": "List input items.",
"parameters": [ "parameters": [
{ {
"name": "response_id", "name": "response_id",
@ -9521,6 +9541,21 @@
"title": "OpenAIResponseText", "title": "OpenAIResponseText",
"description": "Text response configuration for OpenAI responses." "description": "Text response configuration for OpenAI responses."
}, },
"ResponseShieldSpec": {
"type": "object",
"properties": {
"type": {
"type": "string",
"description": "The type/identifier of the shield."
}
},
"additionalProperties": false,
"required": [
"type"
],
"title": "ResponseShieldSpec",
"description": "Specification for a shield to apply during response generation."
},
"OpenAIResponseInputTool": { "OpenAIResponseInputTool": {
"oneOf": [ "oneOf": [
{ {
@ -13331,12 +13366,13 @@
}, },
{ {
"name": "Files", "name": "Files",
"description": "" "description": "This API is used to upload documents that can be used with other Llama Stack APIs.",
"x-displayName": "Files"
}, },
{ {
"name": "Inference", "name": "Inference",
"description": "This API provides the raw interface to the underlying models. Two kinds of models are supported:\n- LLM models: these models generate \"raw\" and \"chat\" (conversational) completions.\n- Embedding models: these models generate embeddings to be used for semantic search.", "description": "Llama Stack Inference API for generating completions, chat completions, and embeddings.\n\nThis API provides the raw interface to the underlying models. Two kinds of models are supported:\n- LLM models: these models generate \"raw\" and \"chat\" (conversational) completions.\n- Embedding models: these models generate embeddings to be used for semantic search.",
"x-displayName": "Llama Stack Inference API for generating completions, chat completions, and embeddings." "x-displayName": "Inference"
}, },
{ {
"name": "Models", "name": "Models",
@ -13348,7 +13384,8 @@
}, },
{ {
"name": "Safety", "name": "Safety",
"description": "" "description": "OpenAI-compatible Moderations API.",
"x-displayName": "Safety"
}, },
{ {
"name": "Telemetry", "name": "Telemetry",

View file

@ -1033,8 +1033,8 @@ paths:
$ref: '#/components/responses/DefaultError' $ref: '#/components/responses/DefaultError'
tags: tags:
- Inference - Inference
summary: List all chat completions. summary: List chat completions.
description: List all chat completions. description: List chat completions.
parameters: parameters:
- name: after - name: after
in: query in: query
@ -1087,10 +1087,10 @@ paths:
$ref: '#/components/responses/DefaultError' $ref: '#/components/responses/DefaultError'
tags: tags:
- Inference - Inference
summary: >- summary: Create chat completions.
Generate an OpenAI-compatible chat completion for the given messages using
the specified model.
description: >- description: >-
Create chat completions.
Generate an OpenAI-compatible chat completion for the given messages using Generate an OpenAI-compatible chat completion for the given messages using
the specified model. the specified model.
parameters: [] parameters: []
@ -1122,8 +1122,11 @@ paths:
$ref: '#/components/responses/DefaultError' $ref: '#/components/responses/DefaultError'
tags: tags:
- Inference - Inference
summary: Describe a chat completion by its ID. summary: Get chat completion.
description: Describe a chat completion by its ID. description: >-
Get chat completion.
Describe a chat completion by its ID.
parameters: parameters:
- name: completion_id - name: completion_id
in: path in: path
@ -1153,10 +1156,10 @@ paths:
$ref: '#/components/responses/DefaultError' $ref: '#/components/responses/DefaultError'
tags: tags:
- Inference - Inference
summary: >- summary: Create completion.
Generate an OpenAI-compatible completion for the given prompt using the specified
model.
description: >- description: >-
Create completion.
Generate an OpenAI-compatible completion for the given prompt using the specified Generate an OpenAI-compatible completion for the given prompt using the specified
model. model.
parameters: [] parameters: []
@ -1189,10 +1192,10 @@ paths:
$ref: '#/components/responses/DefaultError' $ref: '#/components/responses/DefaultError'
tags: tags:
- Inference - Inference
summary: >- summary: Create embeddings.
Generate OpenAI-compatible embeddings for the given input using the specified
model.
description: >- description: >-
Create embeddings.
Generate OpenAI-compatible embeddings for the given input using the specified Generate OpenAI-compatible embeddings for the given input using the specified
model. model.
parameters: [] parameters: []
@ -1225,9 +1228,10 @@ paths:
$ref: '#/components/responses/DefaultError' $ref: '#/components/responses/DefaultError'
tags: tags:
- Files - Files
summary: >- summary: List files.
Returns a list of files that belong to the user's organization.
description: >- description: >-
List files.
Returns a list of files that belong to the user's organization. Returns a list of files that belong to the user's organization.
parameters: parameters:
- name: after - name: after
@ -1285,11 +1289,13 @@ paths:
$ref: '#/components/responses/DefaultError' $ref: '#/components/responses/DefaultError'
tags: tags:
- Files - Files
summary: >- summary: Upload file.
Upload a file that can be used across various endpoints.
description: >- description: >-
Upload file.
Upload a file that can be used across various endpoints. Upload a file that can be used across various endpoints.
The file upload should be a multipart form request with: The file upload should be a multipart form request with:
- file: The File object (not file name) to be uploaded. - file: The File object (not file name) to be uploaded.
@ -1338,9 +1344,10 @@ paths:
$ref: '#/components/responses/DefaultError' $ref: '#/components/responses/DefaultError'
tags: tags:
- Files - Files
summary: >- summary: Retrieve file.
Returns information about a specific file.
description: >- description: >-
Retrieve file.
Returns information about a specific file. Returns information about a specific file.
parameters: parameters:
- name: file_id - name: file_id
@ -1372,8 +1379,8 @@ paths:
$ref: '#/components/responses/DefaultError' $ref: '#/components/responses/DefaultError'
tags: tags:
- Files - Files
summary: Delete a file. summary: Delete file.
description: Delete a file. description: Delete file.
parameters: parameters:
- name: file_id - name: file_id
in: path in: path
@ -1405,9 +1412,10 @@ paths:
$ref: '#/components/responses/DefaultError' $ref: '#/components/responses/DefaultError'
tags: tags:
- Files - Files
summary: >- summary: Retrieve file content.
Returns the contents of the specified file.
description: >- description: >-
Retrieve file content.
Returns the contents of the specified file. Returns the contents of the specified file.
parameters: parameters:
- name: file_id - name: file_id
@ -1464,9 +1472,10 @@ paths:
$ref: '#/components/responses/DefaultError' $ref: '#/components/responses/DefaultError'
tags: tags:
- Safety - Safety
summary: >- summary: Create moderation.
Classifies if text and/or image inputs are potentially harmful.
description: >- description: >-
Create moderation.
Classifies if text and/or image inputs are potentially harmful. Classifies if text and/or image inputs are potentially harmful.
parameters: [] parameters: []
requestBody: requestBody:
@ -1497,8 +1506,8 @@ paths:
$ref: '#/components/responses/DefaultError' $ref: '#/components/responses/DefaultError'
tags: tags:
- Agents - Agents
summary: List all OpenAI responses. summary: List all responses.
description: List all OpenAI responses. description: List all responses.
parameters: parameters:
- name: after - name: after
in: query in: query
@ -1549,8 +1558,8 @@ paths:
$ref: '#/components/responses/DefaultError' $ref: '#/components/responses/DefaultError'
tags: tags:
- Agents - Agents
summary: Create a new OpenAI response. summary: Create a model response.
description: Create a new OpenAI response. description: Create a model response.
parameters: [] parameters: []
requestBody: requestBody:
content: content:
@ -1559,6 +1568,18 @@ paths:
$ref: '#/components/schemas/CreateOpenaiResponseRequest' $ref: '#/components/schemas/CreateOpenaiResponseRequest'
required: true required: true
deprecated: true deprecated: true
x-llama-stack-extra-body-params:
- name: shields
schema:
type: array
items:
oneOf:
- type: string
- $ref: '#/components/schemas/ResponseShieldSpec'
description: >-
List of shields to apply during response generation. Shields provide safety
and content moderation.
required: false
/v1/openai/v1/responses/{response_id}: /v1/openai/v1/responses/{response_id}:
get: get:
responses: responses:
@ -1580,8 +1601,8 @@ paths:
$ref: '#/components/responses/DefaultError' $ref: '#/components/responses/DefaultError'
tags: tags:
- Agents - Agents
summary: Retrieve an OpenAI response by its ID. summary: Get a model response.
description: Retrieve an OpenAI response by its ID. description: Get a model response.
parameters: parameters:
- name: response_id - name: response_id
in: path in: path
@ -1611,8 +1632,8 @@ paths:
$ref: '#/components/responses/DefaultError' $ref: '#/components/responses/DefaultError'
tags: tags:
- Agents - Agents
summary: Delete an OpenAI response by its ID. summary: Delete a response.
description: Delete an OpenAI response by its ID. description: Delete a response.
parameters: parameters:
- name: response_id - name: response_id
in: path in: path
@ -1642,10 +1663,8 @@ paths:
$ref: '#/components/responses/DefaultError' $ref: '#/components/responses/DefaultError'
tags: tags:
- Agents - Agents
summary: >- summary: List input items.
List input items for a given OpenAI response. description: List input items.
description: >-
List input items for a given OpenAI response.
parameters: parameters:
- name: response_id - name: response_id
in: path in: path
@ -7076,6 +7095,18 @@ components:
title: OpenAIResponseText title: OpenAIResponseText
description: >- description: >-
Text response configuration for OpenAI responses. Text response configuration for OpenAI responses.
ResponseShieldSpec:
type: object
properties:
type:
type: string
description: The type/identifier of the shield.
additionalProperties: false
required:
- type
title: ResponseShieldSpec
description: >-
Specification for a shield to apply during response generation.
OpenAIResponseInputTool: OpenAIResponseInputTool:
oneOf: oneOf:
- $ref: '#/components/schemas/OpenAIResponseInputToolWebSearch' - $ref: '#/components/schemas/OpenAIResponseInputToolWebSearch'
@ -9987,9 +10018,16 @@ tags:
x-displayName: >- x-displayName: >-
Llama Stack Evaluation API for running evaluations on model and agent candidates. Llama Stack Evaluation API for running evaluations on model and agent candidates.
- name: Files - name: Files
description: '' description: >-
This API is used to upload documents that can be used with other Llama Stack
APIs.
x-displayName: Files
- name: Inference - name: Inference
description: >- description: >-
Llama Stack Inference API for generating completions, chat completions, and
embeddings.
This API provides the raw interface to the underlying models. Two kinds of models This API provides the raw interface to the underlying models. Two kinds of models
are supported: are supported:
@ -9997,15 +10035,14 @@ tags:
- Embedding models: these models generate embeddings to be used for semantic - Embedding models: these models generate embeddings to be used for semantic
search. search.
x-displayName: >- x-displayName: Inference
Llama Stack Inference API for generating completions, chat completions, and
embeddings.
- name: Models - name: Models
description: '' description: ''
- name: PostTraining (Coming Soon) - name: PostTraining (Coming Soon)
description: '' description: ''
- name: Safety - name: Safety
description: '' description: OpenAI-compatible Moderations API.
x-displayName: Safety
- name: Telemetry - name: Telemetry
description: '' description: ''
- name: VectorIO - name: VectorIO

Binary file not shown.

(image changed: 196 KiB before, 604 KiB after)

File diff suppressed because it is too large

File diff suppressed because it is too large

File diff suppressed because it is too large

File diff suppressed because it is too large

View file

@ -28,7 +28,7 @@ from llama_stack.apis.inference import (
from llama_stack.apis.safety import SafetyViolation from llama_stack.apis.safety import SafetyViolation
from llama_stack.apis.tools import ToolDef from llama_stack.apis.tools import ToolDef
from llama_stack.apis.version import LLAMA_STACK_API_V1, LLAMA_STACK_API_V1ALPHA from llama_stack.apis.version import LLAMA_STACK_API_V1, LLAMA_STACK_API_V1ALPHA
from llama_stack.schema_utils import json_schema_type, register_schema, webmethod from llama_stack.schema_utils import ExtraBodyField, json_schema_type, register_schema, webmethod
from .openai_responses import ( from .openai_responses import (
ListOpenAIResponseInputItem, ListOpenAIResponseInputItem,
@ -42,6 +42,20 @@ from .openai_responses import (
) )
@json_schema_type
class ResponseShieldSpec(BaseModel):
"""Specification for a shield to apply during response generation.
:param type: The type/identifier of the shield.
"""
type: str
# TODO: more fields to be added for shield configuration
ResponseShield = str | ResponseShieldSpec
class Attachment(BaseModel): class Attachment(BaseModel):
"""An attachment to an agent turn. """An attachment to an agent turn.
@ -783,7 +797,7 @@ class Agents(Protocol):
self, self,
response_id: str, response_id: str,
) -> OpenAIResponseObject: ) -> OpenAIResponseObject:
"""Retrieve an OpenAI response by its ID. """Get a model response.
:param response_id: The ID of the OpenAI response to retrieve. :param response_id: The ID of the OpenAI response to retrieve.
:returns: An OpenAIResponseObject. :returns: An OpenAIResponseObject.
@ -805,13 +819,20 @@ class Agents(Protocol):
tools: list[OpenAIResponseInputTool] | None = None, tools: list[OpenAIResponseInputTool] | None = None,
include: list[str] | None = None, include: list[str] | None = None,
max_infer_iters: int | None = 10, # this is an extension to the OpenAI API max_infer_iters: int | None = 10, # this is an extension to the OpenAI API
shields: Annotated[
list[ResponseShield] | None,
ExtraBodyField(
"List of shields to apply during response generation. Shields provide safety and content moderation."
),
] = None,
) -> OpenAIResponseObject | AsyncIterator[OpenAIResponseObjectStream]: ) -> OpenAIResponseObject | AsyncIterator[OpenAIResponseObjectStream]:
"""Create a new OpenAI response. """Create a model response.
:param input: Input message(s) to create the response. :param input: Input message(s) to create the response.
:param model: The underlying LLM used for completions. :param model: The underlying LLM used for completions.
:param previous_response_id: (Optional) if specified, the new response will be a continuation of the previous response. This can be used to easily fork-off new responses from existing responses. :param previous_response_id: (Optional) if specified, the new response will be a continuation of the previous response. This can be used to easily fork-off new responses from existing responses.
:param include: (Optional) Additional fields to include in the response. :param include: (Optional) Additional fields to include in the response.
:param shields: (Optional) List of shields to apply during response generation. Can be shield IDs (strings) or shield specifications.
:returns: An OpenAIResponseObject. :returns: An OpenAIResponseObject.
""" """
... ...
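Because shields is declared via ExtraBodyField, it is not part of the OpenAI request schema; an OpenAI-compatible client would pass it through extra_body instead. A minimal sketch, assuming a locally running stack (base URL, model, and shield ID are placeholders):

from openai import OpenAI

client = OpenAI(base_url="http://localhost:8321/v1/openai/v1", api_key="none")

response = client.responses.create(
    model="meta-llama/Llama-3.3-70B-Instruct",
    input="Tell me about machine learning",
    # Not in the OpenAI schema, so it travels via extra_body and is documented
    # under x-llama-stack-extra-body-params.
    extra_body={"shields": ["llama-guard"]},
)
print(response.output_text)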
@ -825,7 +846,7 @@ class Agents(Protocol):
model: str | None = None, model: str | None = None,
order: Order | None = Order.desc, order: Order | None = Order.desc,
) -> ListOpenAIResponseObject: ) -> ListOpenAIResponseObject:
"""List all OpenAI responses. """List all responses.
:param after: The ID of the last response to return. :param after: The ID of the last response to return.
:param limit: The number of responses to return. :param limit: The number of responses to return.
@ -848,7 +869,7 @@ class Agents(Protocol):
limit: int | None = 20, limit: int | None = 20,
order: Order | None = Order.desc, order: Order | None = Order.desc,
) -> ListOpenAIResponseInputItem: ) -> ListOpenAIResponseInputItem:
"""List input items for a given OpenAI response. """List input items.
:param response_id: The ID of the response to retrieve input items for. :param response_id: The ID of the response to retrieve input items for.
:param after: An item ID to list items after, used for pagination. :param after: An item ID to list items after, used for pagination.
@ -863,7 +884,7 @@ class Agents(Protocol):
@webmethod(route="/openai/v1/responses/{response_id}", method="DELETE", level=LLAMA_STACK_API_V1, deprecated=True) @webmethod(route="/openai/v1/responses/{response_id}", method="DELETE", level=LLAMA_STACK_API_V1, deprecated=True)
@webmethod(route="/responses/{response_id}", method="DELETE", level=LLAMA_STACK_API_V1) @webmethod(route="/responses/{response_id}", method="DELETE", level=LLAMA_STACK_API_V1)
async def delete_openai_response(self, response_id: str) -> OpenAIDeleteResponseObject: async def delete_openai_response(self, response_id: str) -> OpenAIDeleteResponseObject:
"""Delete an OpenAI response by its ID. """Delete a response.
:param response_id: The ID of the OpenAI response to delete. :param response_id: The ID of the OpenAI response to delete.
:returns: An OpenAIDeleteResponseObject :returns: An OpenAIDeleteResponseObject

View file

@ -888,6 +888,10 @@ class OpenAIResponseObjectWithInput(OpenAIResponseObject):
input: list[OpenAIResponseInput] input: list[OpenAIResponseInput]
def to_response_object(self) -> OpenAIResponseObject:
"""Convert to OpenAIResponseObject by excluding input field."""
return OpenAIResponseObject(**{k: v for k, v in self.model_dump().items() if k != "input"})
@json_schema_type @json_schema_type
class ListOpenAIResponseObject(BaseModel): class ListOpenAIResponseObject(BaseModel):

View file

@ -0,0 +1,31 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .conversations import (
Conversation,
ConversationCreateRequest,
ConversationDeletedResource,
ConversationItem,
ConversationItemCreateRequest,
ConversationItemDeletedResource,
ConversationItemList,
Conversations,
ConversationUpdateRequest,
Metadata,
)
__all__ = [
"Conversation",
"ConversationCreateRequest",
"ConversationDeletedResource",
"ConversationItem",
"ConversationItemCreateRequest",
"ConversationItemDeletedResource",
"ConversationItemList",
"Conversations",
"ConversationUpdateRequest",
"Metadata",
]

View file

@ -0,0 +1,260 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Annotated, Literal, Protocol, runtime_checkable
from openai import NOT_GIVEN
from openai._types import NotGiven
from openai.types.responses.response_includable import ResponseIncludable
from pydantic import BaseModel, Field
from llama_stack.apis.agents.openai_responses import (
OpenAIResponseMessage,
OpenAIResponseOutputMessageFileSearchToolCall,
OpenAIResponseOutputMessageFunctionToolCall,
OpenAIResponseOutputMessageMCPCall,
OpenAIResponseOutputMessageMCPListTools,
OpenAIResponseOutputMessageWebSearchToolCall,
)
from llama_stack.apis.version import LLAMA_STACK_API_V1
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
Metadata = dict[str, str]
@json_schema_type
class Conversation(BaseModel):
"""OpenAI-compatible conversation object."""
id: str = Field(..., description="The unique ID of the conversation.")
object: Literal["conversation"] = Field(
default="conversation", description="The object type, which is always conversation."
)
created_at: int = Field(
..., description="The time at which the conversation was created, measured in seconds since the Unix epoch."
)
metadata: Metadata | None = Field(
default=None,
description="Set of 16 key-value pairs that can be attached to an object. This can be useful for storing additional information about the object in a structured format, and querying for objects via API or the dashboard.",
)
items: list[dict] | None = Field(
default=None,
description="Initial items to include in the conversation context. You may add up to 20 items at a time.",
)
@json_schema_type
class ConversationMessage(BaseModel):
"""OpenAI-compatible message item for conversations."""
id: str = Field(..., description="unique identifier for this message")
content: list[dict] = Field(..., description="message content")
role: str = Field(..., description="message role")
status: str = Field(..., description="message status")
type: Literal["message"] = "message"
object: Literal["message"] = "message"
ConversationItem = Annotated[
OpenAIResponseMessage
| OpenAIResponseOutputMessageFunctionToolCall
| OpenAIResponseOutputMessageFileSearchToolCall
| OpenAIResponseOutputMessageWebSearchToolCall
| OpenAIResponseOutputMessageMCPCall
| OpenAIResponseOutputMessageMCPListTools,
Field(discriminator="type"),
]
register_schema(ConversationItem, name="ConversationItem")
# Using OpenAI types directly caused issues but some notes for reference:
# Note that ConversationItem is a Annotated Union of the types below:
# from openai.types.responses import *
# from openai.types.responses.response_item import *
# from openai.types.conversations import ConversationItem
# f = [
# ResponseFunctionToolCallItem,
# ResponseFunctionToolCallOutputItem,
# ResponseFileSearchToolCall,
# ResponseFunctionWebSearch,
# ImageGenerationCall,
# ResponseComputerToolCall,
# ResponseComputerToolCallOutputItem,
# ResponseReasoningItem,
# ResponseCodeInterpreterToolCall,
# LocalShellCall,
# LocalShellCallOutput,
# McpListTools,
# McpApprovalRequest,
# McpApprovalResponse,
# McpCall,
# ResponseCustomToolCall,
# ResponseCustomToolCallOutput
# ]
@json_schema_type
class ConversationCreateRequest(BaseModel):
"""Request body for creating a conversation."""
items: list[ConversationItem] | None = Field(
default=[],
description="Initial items to include in the conversation context. You may add up to 20 items at a time.",
max_length=20,
)
metadata: Metadata | None = Field(
default={},
description="Set of 16 key-value pairs that can be attached to an object. Useful for storing additional information",
max_length=16,
)
@json_schema_type
class ConversationUpdateRequest(BaseModel):
"""Request body for updating a conversation."""
metadata: Metadata = Field(
...,
description="Set of 16 key-value pairs that can be attached to an object. This can be useful for storing additional information about the object in a structured format, and querying for objects via API or the dashboard. Keys are strings with a maximum length of 64 characters. Values are strings with a maximum length of 512 characters.",
)
@json_schema_type
class ConversationDeletedResource(BaseModel):
"""Response for deleted conversation."""
id: str = Field(..., description="The deleted conversation identifier")
object: str = Field(default="conversation.deleted", description="Object type")
deleted: bool = Field(default=True, description="Whether the object was deleted")
@json_schema_type
class ConversationItemCreateRequest(BaseModel):
"""Request body for creating conversation items."""
items: list[ConversationItem] = Field(
...,
description="Items to include in the conversation context. You may add up to 20 items at a time.",
max_length=20,
)
@json_schema_type
class ConversationItemList(BaseModel):
"""List of conversation items with pagination."""
object: str = Field(default="list", description="Object type")
data: list[ConversationItem] = Field(..., description="List of conversation items")
first_id: str | None = Field(default=None, description="The ID of the first item in the list")
last_id: str | None = Field(default=None, description="The ID of the last item in the list")
has_more: bool = Field(default=False, description="Whether there are more items available")
@json_schema_type
class ConversationItemDeletedResource(BaseModel):
"""Response for deleted conversation item."""
id: str = Field(..., description="The deleted item identifier")
object: str = Field(default="conversation.item.deleted", description="Object type")
deleted: bool = Field(default=True, description="Whether the object was deleted")
@runtime_checkable
@trace_protocol
class Conversations(Protocol):
"""Protocol for conversation management operations."""
@webmethod(route="/conversations", method="POST", level=LLAMA_STACK_API_V1)
async def create_conversation(
self, items: list[ConversationItem] | None = None, metadata: Metadata | None = None
) -> Conversation:
"""Create a conversation.
:param items: Initial items to include in the conversation context.
:param metadata: Set of key-value pairs that can be attached to an object.
:returns: The created conversation object.
"""
...
@webmethod(route="/conversations/{conversation_id}", method="GET", level=LLAMA_STACK_API_V1)
async def get_conversation(self, conversation_id: str) -> Conversation:
"""Get a conversation with the given ID.
:param conversation_id: The conversation identifier.
:returns: The conversation object.
"""
...
@webmethod(route="/conversations/{conversation_id}", method="POST", level=LLAMA_STACK_API_V1)
async def update_conversation(self, conversation_id: str, metadata: Metadata) -> Conversation:
"""Update a conversation's metadata with the given ID.
:param conversation_id: The conversation identifier.
:param metadata: Set of key-value pairs that can be attached to an object.
:returns: The updated conversation object.
"""
...
@webmethod(route="/conversations/{conversation_id}", method="DELETE", level=LLAMA_STACK_API_V1)
async def openai_delete_conversation(self, conversation_id: str) -> ConversationDeletedResource:
"""Delete a conversation with the given ID.
:param conversation_id: The conversation identifier.
:returns: The deleted conversation resource.
"""
...
@webmethod(route="/conversations/{conversation_id}/items", method="POST", level=LLAMA_STACK_API_V1)
async def add_items(self, conversation_id: str, items: list[ConversationItem]) -> ConversationItemList:
"""Create items in the conversation.
:param conversation_id: The conversation identifier.
:param items: Items to include in the conversation context.
:returns: List of created items.
"""
...
@webmethod(route="/conversations/{conversation_id}/items/{item_id}", method="GET", level=LLAMA_STACK_API_V1)
async def retrieve(self, conversation_id: str, item_id: str) -> ConversationItem:
"""Retrieve a conversation item.
:param conversation_id: The conversation identifier.
:param item_id: The item identifier.
:returns: The conversation item.
"""
...
@webmethod(route="/conversations/{conversation_id}/items", method="GET", level=LLAMA_STACK_API_V1)
async def list(
self,
conversation_id: str,
after: str | NotGiven = NOT_GIVEN,
include: list[ResponseIncludable] | NotGiven = NOT_GIVEN,
limit: int | NotGiven = NOT_GIVEN,
order: Literal["asc", "desc"] | NotGiven = NOT_GIVEN,
) -> ConversationItemList:
"""List items in the conversation.
:param conversation_id: The conversation identifier.
:param after: An item ID to list items after, used in pagination.
:param include: Specify additional output data to include in the response.
:param limit: A limit on the number of objects to be returned (1-100, default 20).
:param order: The order to return items in (asc or desc, default desc).
:returns: List of conversation items.
"""
...
@webmethod(route="/conversations/{conversation_id}/items/{item_id}", method="DELETE", level=LLAMA_STACK_API_V1)
async def openai_delete_conversation_item(
self, conversation_id: str, item_id: str
) -> ConversationItemDeletedResource:
"""Delete a conversation item.
:param conversation_id: The conversation identifier.
:param item_id: The item identifier.
:returns: The deleted item resource.
"""
...
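A minimal usage sketch against this protocol, assuming impl is whatever concrete Conversations provider the stack wires up (metadata values are placeholders):

from llama_stack.apis.conversations import Conversations


async def demo(impl: Conversations) -> None:
    # Create an (initially empty) conversation carrying some metadata.
    conv = await impl.create_conversation(metadata={"topic": "greeting"})

    # Metadata updates return the refreshed conversation object.
    conv = await impl.update_conversation(conv.id, metadata={"topic": "greeting", "owner": "demo"})

    # Page through stored items (none yet), then clean up.
    listing = await impl.list(conv.id, limit=10, order="desc")
    print(listing.has_more, [item.type for item in listing.data])

    await impl.openai_delete_conversation(conv.id)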

View file

@ -129,6 +129,7 @@ class Api(Enum, metaclass=DynamicApiMeta):
tool_groups = "tool_groups" tool_groups = "tool_groups"
files = "files" files = "files"
prompts = "prompts" prompts = "prompts"
conversations = "conversations"
# built-in API # built-in API
inspect = "inspect" inspect = "inspect"

View file

@ -104,6 +104,11 @@ class OpenAIFileDeleteResponse(BaseModel):
@runtime_checkable @runtime_checkable
@trace_protocol @trace_protocol
class Files(Protocol): class Files(Protocol):
"""Files
This API is used to upload documents that can be used with other Llama Stack APIs.
"""
# OpenAI Files API Endpoints # OpenAI Files API Endpoints
@webmethod(route="/openai/v1/files", method="POST", level=LLAMA_STACK_API_V1, deprecated=True) @webmethod(route="/openai/v1/files", method="POST", level=LLAMA_STACK_API_V1, deprecated=True)
@webmethod(route="/files", method="POST", level=LLAMA_STACK_API_V1) @webmethod(route="/files", method="POST", level=LLAMA_STACK_API_V1)
@ -113,7 +118,8 @@ class Files(Protocol):
purpose: Annotated[OpenAIFilePurpose, Form()], purpose: Annotated[OpenAIFilePurpose, Form()],
expires_after: Annotated[ExpiresAfter | None, Form()] = None, expires_after: Annotated[ExpiresAfter | None, Form()] = None,
) -> OpenAIFileObject: ) -> OpenAIFileObject:
""" """Upload file.
Upload a file that can be used across various endpoints. Upload a file that can be used across various endpoints.
The file upload should be a multipart form request with: The file upload should be a multipart form request with:
@ -137,7 +143,8 @@ class Files(Protocol):
order: Order | None = Order.desc, order: Order | None = Order.desc,
purpose: OpenAIFilePurpose | None = None, purpose: OpenAIFilePurpose | None = None,
) -> ListOpenAIFileResponse: ) -> ListOpenAIFileResponse:
""" """List files.
Returns a list of files that belong to the user's organization. Returns a list of files that belong to the user's organization.
:param after: A cursor for use in pagination. `after` is an object ID that defines your place in the list. For instance, if you make a list request and receive 100 objects, ending with obj_foo, your subsequent call can include after=obj_foo in order to fetch the next page of the list. :param after: A cursor for use in pagination. `after` is an object ID that defines your place in the list. For instance, if you make a list request and receive 100 objects, ending with obj_foo, your subsequent call can include after=obj_foo in order to fetch the next page of the list.
@ -154,7 +161,8 @@ class Files(Protocol):
self, self,
file_id: str, file_id: str,
) -> OpenAIFileObject: ) -> OpenAIFileObject:
""" """Retrieve file.
Returns information about a specific file. Returns information about a specific file.
:param file_id: The ID of the file to use for this request. :param file_id: The ID of the file to use for this request.
@ -168,8 +176,7 @@ class Files(Protocol):
self, self,
file_id: str, file_id: str,
) -> OpenAIFileDeleteResponse: ) -> OpenAIFileDeleteResponse:
""" """Delete file.
Delete a file.
:param file_id: The ID of the file to use for this request. :param file_id: The ID of the file to use for this request.
:returns: An OpenAIFileDeleteResponse indicating successful deletion. :returns: An OpenAIFileDeleteResponse indicating successful deletion.
@ -182,7 +189,8 @@ class Files(Protocol):
self, self,
file_id: str, file_id: str,
) -> Response: ) -> Response:
""" """Retrieve file content.
Returns the contents of the specified file. Returns the contents of the specified file.
:param file_id: The ID of the file to use for this request. :param file_id: The ID of the file to use for this request.

View file

@ -982,45 +982,6 @@ class InferenceProvider(Protocol):
model_store: ModelStore | None = None model_store: ModelStore | None = None
async def chat_completion(
self,
model_id: str,
messages: list[Message],
sampling_params: SamplingParams | None = None,
tools: list[ToolDefinition] | None = None,
tool_choice: ToolChoice | None = ToolChoice.auto,
tool_prompt_format: ToolPromptFormat | None = None,
response_format: ResponseFormat | None = None,
stream: bool | None = False,
logprobs: LogProbConfig | None = None,
tool_config: ToolConfig | None = None,
) -> ChatCompletionResponse | AsyncIterator[ChatCompletionResponseStreamChunk]:
"""Generate a chat completion for the given messages using the specified model.
:param model_id: The identifier of the model to use. The model must be registered with Llama Stack and available via the /models endpoint.
:param messages: List of messages in the conversation.
:param sampling_params: Parameters to control the sampling strategy.
:param tools: (Optional) List of tool definitions available to the model.
:param tool_choice: (Optional) Whether tool use is required or automatic. Defaults to ToolChoice.auto.
.. deprecated::
Use tool_config instead.
:param tool_prompt_format: (Optional) Instructs the model how to format tool calls. By default, Llama Stack will attempt to use a format that is best adapted to the model.
- `ToolPromptFormat.json`: The tool calls are formatted as a JSON object.
- `ToolPromptFormat.function_tag`: The tool calls are enclosed in a <function=function_name> tag.
- `ToolPromptFormat.python_list`: The tool calls are output as Python syntax -- a list of function calls.
.. deprecated::
Use tool_config instead.
:param response_format: (Optional) Grammar specification for guided (structured) decoding. There are two options:
- `ResponseFormat.json_schema`: The grammar is a JSON schema. Most providers support this format.
- `ResponseFormat.grammar`: The grammar is a BNF grammar. This format is more flexible, but not all providers support it.
:param stream: (Optional) If True, generate an SSE event stream of the response. Defaults to False.
:param logprobs: (Optional) If specified, log probabilities for each token position will be returned.
:param tool_config: (Optional) Configuration for tool use.
:returns: If stream=False, returns a ChatCompletionResponse with the full completion.
If stream=True, returns an SSE event stream of ChatCompletionResponseStreamChunk.
"""
...
@webmethod(route="/inference/rerank", method="POST", level=LLAMA_STACK_API_V1ALPHA) @webmethod(route="/inference/rerank", method="POST", level=LLAMA_STACK_API_V1ALPHA)
async def rerank( async def rerank(
self, self,
@ -1081,7 +1042,9 @@ class InferenceProvider(Protocol):
# for fill-in-the-middle type completion # for fill-in-the-middle type completion
suffix: str | None = None, suffix: str | None = None,
) -> OpenAICompletion: ) -> OpenAICompletion:
"""Generate an OpenAI-compatible completion for the given prompt using the specified model. """Create completion.
Generate an OpenAI-compatible completion for the given prompt using the specified model.
:param model: The identifier of the model to use. The model must be registered with Llama Stack and available via the /models endpoint. :param model: The identifier of the model to use. The model must be registered with Llama Stack and available via the /models endpoint.
:param prompt: The prompt to generate a completion for. :param prompt: The prompt to generate a completion for.
@ -1138,7 +1101,9 @@ class InferenceProvider(Protocol):
top_p: float | None = None, top_p: float | None = None,
user: str | None = None, user: str | None = None,
) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]: ) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
"""Generate an OpenAI-compatible chat completion for the given messages using the specified model. """Create chat completions.
Generate an OpenAI-compatible chat completion for the given messages using the specified model.
:param model: The identifier of the model to use. The model must be registered with Llama Stack and available via the /models endpoint. :param model: The identifier of the model to use. The model must be registered with Llama Stack and available via the /models endpoint.
:param messages: List of messages in the conversation. :param messages: List of messages in the conversation.
@ -1182,7 +1147,9 @@ class InferenceProvider(Protocol):
dimensions: int | None = None, dimensions: int | None = None,
user: str | None = None, user: str | None = None,
) -> OpenAIEmbeddingsResponse: ) -> OpenAIEmbeddingsResponse:
"""Generate OpenAI-compatible embeddings for the given input using the specified model. """Create embeddings.
Generate OpenAI-compatible embeddings for the given input using the specified model.
:param model: The identifier of the model to use. The model must be an embedding model registered with Llama Stack and available via the /models endpoint. :param model: The identifier of the model to use. The model must be an embedding model registered with Llama Stack and available via the /models endpoint.
:param input: Input text to embed, encoded as a string or array of strings. To embed multiple inputs in a single request, pass an array of strings. :param input: Input text to embed, encoded as a string or array of strings. To embed multiple inputs in a single request, pass an array of strings.
@ -1195,7 +1162,9 @@ class InferenceProvider(Protocol):
class Inference(InferenceProvider): class Inference(InferenceProvider):
"""Llama Stack Inference API for generating completions, chat completions, and embeddings. """Inference
Llama Stack Inference API for generating completions, chat completions, and embeddings.
This API provides the raw interface to the underlying models. Two kinds of models are supported: This API provides the raw interface to the underlying models. Two kinds of models are supported:
- LLM models: these models generate "raw" and "chat" (conversational) completions. - LLM models: these models generate "raw" and "chat" (conversational) completions.
@ -1216,7 +1185,7 @@ class Inference(InferenceProvider):
model: str | None = None, model: str | None = None,
order: Order | None = Order.desc, order: Order | None = Order.desc,
) -> ListOpenAIChatCompletionResponse: ) -> ListOpenAIChatCompletionResponse:
"""List all chat completions. """List chat completions.
:param after: The ID of the last chat completion to return. :param after: The ID of the last chat completion to return.
:param limit: The maximum number of chat completions to return. :param limit: The maximum number of chat completions to return.
@ -1237,10 +1206,11 @@ class Inference(InferenceProvider):
method="GET", method="GET",
level=LLAMA_STACK_API_V1, level=LLAMA_STACK_API_V1,
) )
async def get_chat_completion( @webmethod(route="/chat/completions/{completion_id}", method="GET", level=LLAMA_STACK_API_V1)
self, completion_id: str async def get_chat_completion(self, completion_id: str) -> OpenAICompletionWithInputMessages:
) -> OpenAICompletionWithInputMessages: """Get chat completion.
"""Describe a chat completion by its ID.
Describe a chat completion by its ID.
:param completion_id: ID of the chat completion. :param completion_id: ID of the chat completion.
:returns: A OpenAICompletionWithInputMessages. :returns: A OpenAICompletionWithInputMessages.

View file

@ -58,9 +58,16 @@ class ListRoutesResponse(BaseModel):
@runtime_checkable @runtime_checkable
class Inspect(Protocol): class Inspect(Protocol):
"""Inspect
APIs for inspecting the Llama Stack service, including health status, available API routes with methods and implementing providers.
"""
@webmethod(route="/inspect/routes", method="GET", level=LLAMA_STACK_API_V1) @webmethod(route="/inspect/routes", method="GET", level=LLAMA_STACK_API_V1)
async def list_routes(self) -> ListRoutesResponse: async def list_routes(self) -> ListRoutesResponse:
"""List all available API routes with their methods and implementing providers. """List routes.
List all available API routes with their methods and implementing providers.
:returns: Response containing information about all available routes. :returns: Response containing information about all available routes.
""" """
@ -68,7 +75,9 @@ class Inspect(Protocol):
@webmethod(route="/health", method="GET", level=LLAMA_STACK_API_V1) @webmethod(route="/health", method="GET", level=LLAMA_STACK_API_V1)
async def health(self) -> HealthInfo: async def health(self) -> HealthInfo:
"""Get the current health status of the service. """Get health status.
Get the current health status of the service.
:returns: Health information indicating if the service is operational. :returns: Health information indicating if the service is operational.
""" """
@ -76,7 +85,9 @@ class Inspect(Protocol):
@webmethod(route="/version", method="GET", level=LLAMA_STACK_API_V1) @webmethod(route="/version", method="GET", level=LLAMA_STACK_API_V1)
async def version(self) -> VersionInfo: async def version(self) -> VersionInfo:
"""Get the version of the service. """Get version.
Get the version of the service.
:returns: Version information containing the service version number. :returns: Version information containing the service version number.
""" """

View file

@ -124,7 +124,9 @@ class Models(Protocol):
self, self,
model_id: str, model_id: str,
) -> Model: ) -> Model:
"""Get a model by its identifier. """Get model.
Get a model by its identifier.
:param model_id: The identifier of the model to get. :param model_id: The identifier of the model to get.
:returns: A Model. :returns: A Model.
@ -140,7 +142,9 @@ class Models(Protocol):
metadata: dict[str, Any] | None = None, metadata: dict[str, Any] | None = None,
model_type: ModelType | None = None, model_type: ModelType | None = None,
) -> Model: ) -> Model:
"""Register a model. """Register model.
Register a model.
:param model_id: The identifier of the model to register. :param model_id: The identifier of the model to register.
:param provider_model_id: The identifier of the model in the provider. :param provider_model_id: The identifier of the model in the provider.
@ -156,7 +160,9 @@ class Models(Protocol):
self, self,
model_id: str, model_id: str,
) -> None: ) -> None:
"""Unregister a model. """Unregister model.
Unregister a model.
:param model_id: The identifier of the model to unregister. :param model_id: The identifier of the model to unregister.
""" """

View file

@ -94,7 +94,9 @@ class ListPromptsResponse(BaseModel):
@runtime_checkable @runtime_checkable
@trace_protocol @trace_protocol
class Prompts(Protocol): class Prompts(Protocol):
"""Protocol for prompt management operations.""" """Prompts
Protocol for prompt management operations."""
@webmethod(route="/prompts", method="GET", level=LLAMA_STACK_API_V1) @webmethod(route="/prompts", method="GET", level=LLAMA_STACK_API_V1)
async def list_prompts(self) -> ListPromptsResponse: async def list_prompts(self) -> ListPromptsResponse:
@ -109,7 +111,9 @@ class Prompts(Protocol):
self, self,
prompt_id: str, prompt_id: str,
) -> ListPromptsResponse: ) -> ListPromptsResponse:
"""List all versions of a specific prompt. """List prompt versions.
List all versions of a specific prompt.
:param prompt_id: The identifier of the prompt to list versions for. :param prompt_id: The identifier of the prompt to list versions for.
:returns: A ListPromptsResponse containing all versions of the prompt. :returns: A ListPromptsResponse containing all versions of the prompt.
@ -122,7 +126,9 @@ class Prompts(Protocol):
prompt_id: str, prompt_id: str,
version: int | None = None, version: int | None = None,
) -> Prompt: ) -> Prompt:
"""Get a prompt by its identifier and optional version. """Get prompt.
Get a prompt by its identifier and optional version.
:param prompt_id: The identifier of the prompt to get. :param prompt_id: The identifier of the prompt to get.
:param version: The version of the prompt to get (defaults to latest). :param version: The version of the prompt to get (defaults to latest).
@ -136,7 +142,9 @@ class Prompts(Protocol):
prompt: str, prompt: str,
variables: list[str] | None = None, variables: list[str] | None = None,
) -> Prompt: ) -> Prompt:
"""Create a new prompt. """Create prompt.
Create a new prompt.
:param prompt: The prompt text content with variable placeholders. :param prompt: The prompt text content with variable placeholders.
:param variables: List of variable names that can be used in the prompt template. :param variables: List of variable names that can be used in the prompt template.
@ -153,7 +161,9 @@ class Prompts(Protocol):
variables: list[str] | None = None, variables: list[str] | None = None,
set_as_default: bool = True, set_as_default: bool = True,
) -> Prompt: ) -> Prompt:
"""Update an existing prompt (increments version). """Update prompt.
Update an existing prompt (increments version).
:param prompt_id: The identifier of the prompt to update. :param prompt_id: The identifier of the prompt to update.
:param prompt: The updated prompt text content. :param prompt: The updated prompt text content.
@ -169,7 +179,9 @@ class Prompts(Protocol):
self, self,
prompt_id: str, prompt_id: str,
) -> None: ) -> None:
"""Delete a prompt. """Delete prompt.
Delete a prompt.
:param prompt_id: The identifier of the prompt to delete. :param prompt_id: The identifier of the prompt to delete.
""" """
@ -181,7 +193,9 @@ class Prompts(Protocol):
prompt_id: str, prompt_id: str,
version: int, version: int,
) -> Prompt: ) -> Prompt:
"""Set which version of a prompt should be the default in get_prompt (latest). """Set prompt version.
Set which version of a prompt should be the default in get_prompt (latest).
:param prompt_id: The identifier of the prompt. :param prompt_id: The identifier of the prompt.
:param version: The version to set as default. :param version: The version to set as default.

View file

@ -42,13 +42,16 @@ class ListProvidersResponse(BaseModel):
@runtime_checkable @runtime_checkable
class Providers(Protocol): class Providers(Protocol):
""" """Providers
Providers API for inspecting, listing, and modifying providers and their configurations. Providers API for inspecting, listing, and modifying providers and their configurations.
""" """
@webmethod(route="/providers", method="GET", level=LLAMA_STACK_API_V1) @webmethod(route="/providers", method="GET", level=LLAMA_STACK_API_V1)
async def list_providers(self) -> ListProvidersResponse: async def list_providers(self) -> ListProvidersResponse:
"""List all available providers. """List providers.
List all available providers.
:returns: A ListProvidersResponse containing information about all providers. :returns: A ListProvidersResponse containing information about all providers.
""" """
@ -56,7 +59,9 @@ class Providers(Protocol):
@webmethod(route="/providers/{provider_id}", method="GET", level=LLAMA_STACK_API_V1) @webmethod(route="/providers/{provider_id}", method="GET", level=LLAMA_STACK_API_V1)
async def inspect_provider(self, provider_id: str) -> ProviderInfo: async def inspect_provider(self, provider_id: str) -> ProviderInfo:
"""Get detailed information about a specific provider. """Get provider.
Get detailed information about a specific provider.
:param provider_id: The ID of the provider to inspect. :param provider_id: The ID of the provider to inspect.
:returns: A ProviderInfo object containing the provider's details. :returns: A ProviderInfo object containing the provider's details.

View file

@ -96,6 +96,11 @@ class ShieldStore(Protocol):
@runtime_checkable @runtime_checkable
@trace_protocol @trace_protocol
class Safety(Protocol): class Safety(Protocol):
"""Safety
OpenAI-compatible Moderations API.
"""
shield_store: ShieldStore shield_store: ShieldStore
@webmethod(route="/safety/run-shield", method="POST", level=LLAMA_STACK_API_V1) @webmethod(route="/safety/run-shield", method="POST", level=LLAMA_STACK_API_V1)
@ -105,7 +110,9 @@ class Safety(Protocol):
messages: list[Message], messages: list[Message],
params: dict[str, Any], params: dict[str, Any],
) -> RunShieldResponse: ) -> RunShieldResponse:
"""Run a shield. """Run shield.
Run a shield.
:param shield_id: The identifier of the shield to run. :param shield_id: The identifier of the shield to run.
:param messages: The messages to run the shield on. :param messages: The messages to run the shield on.
@ -117,7 +124,9 @@ class Safety(Protocol):
@webmethod(route="/openai/v1/moderations", method="POST", level=LLAMA_STACK_API_V1, deprecated=True) @webmethod(route="/openai/v1/moderations", method="POST", level=LLAMA_STACK_API_V1, deprecated=True)
@webmethod(route="/moderations", method="POST", level=LLAMA_STACK_API_V1) @webmethod(route="/moderations", method="POST", level=LLAMA_STACK_API_V1)
async def run_moderation(self, input: str | list[str], model: str) -> ModerationObject: async def run_moderation(self, input: str | list[str], model: str) -> ModerationObject:
"""Classifies if text and/or image inputs are potentially harmful. """Create moderation.
Classifies if text and/or image inputs are potentially harmful.
:param input: Input (or inputs) to classify. :param input: Input (or inputs) to classify.
Can be a single string, an array of strings, or an array of multi-modal input objects similar to other models. Can be a single string, an array of strings, or an array of multi-modal input objects similar to other models.
:param model: The content moderation model you would like to use. :param model: The content moderation model you would like to use.
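For reference, a minimal client-side sketch of the moderations route documented above, using the OpenAI SDK pointed at a local stack; the base URL, API key, and model id are illustrative assumptions.

from openai import OpenAI

client = OpenAI(base_url="http://localhost:8321/v1", api_key="none")  # assumed local stack

result = client.moderations.create(
    model="llama-guard3:1b",            # any registered safety/moderation model (assumed id)
    input="How do I pick a lock?",
)
print(result.results[0].flagged, result.results[0].categories)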

View file

@ -6,11 +6,18 @@
import argparse import argparse
import os import os
import ssl
import subprocess import subprocess
from pathlib import Path from pathlib import Path
import uvicorn
import yaml
from llama_stack.cli.stack.utils import ImageType from llama_stack.cli.stack.utils import ImageType
from llama_stack.cli.subcommand import Subcommand from llama_stack.cli.subcommand import Subcommand
from llama_stack.core.datatypes import LoggingConfig, StackRunConfig
from llama_stack.core.stack import cast_image_name_to_string, replace_env_vars, validate_env_pair
from llama_stack.core.utils.config_resolution import Mode, resolve_config_or_distro
from llama_stack.log import get_logger from llama_stack.log import get_logger
REPO_ROOT = Path(__file__).parent.parent.parent.parent REPO_ROOT = Path(__file__).parent.parent.parent.parent
@ -146,23 +153,7 @@ class StackRun(Subcommand):
# using the current environment packages. # using the current environment packages.
if not image_type and not image_name: if not image_type and not image_name:
logger.info("No image type or image name provided. Assuming environment packages.") logger.info("No image type or image name provided. Assuming environment packages.")
from llama_stack.core.server.server import main as server_main self._uvicorn_run(config_file, args)
# Build the server args from the current args passed to the CLI
server_args = argparse.Namespace()
for arg in vars(args):
# If this is a function, avoid passing it
# "args" contains:
# func=<bound method StackRun._run_stack_run_cmd of <llama_stack.cli.stack.run.StackRun object at 0x10484b010>>
if callable(getattr(args, arg)):
continue
if arg == "config":
server_args.config = str(config_file)
else:
setattr(server_args, arg, getattr(args, arg))
# Run the server
server_main(server_args)
else: else:
run_args = formulate_run_args(image_type, image_name) run_args = formulate_run_args(image_type, image_name)
@ -184,6 +175,76 @@ class StackRun(Subcommand):
run_command(run_args) run_command(run_args)
def _uvicorn_run(self, config_file: Path | None, args: argparse.Namespace) -> None:
if not config_file:
self.parser.error("Config file is required")
# Set environment variables if provided
if args.env:
for env_pair in args.env:
try:
key, value = validate_env_pair(env_pair)
logger.info(f"Setting environment variable {key} => {value}")
os.environ[key] = value
except ValueError as e:
logger.error(f"Error: {str(e)}")
self.parser.error(f"Invalid environment variable format: {env_pair}")
config_file = resolve_config_or_distro(str(config_file), Mode.RUN)
with open(config_file) as fp:
config_contents = yaml.safe_load(fp)
if isinstance(config_contents, dict) and (cfg := config_contents.get("logging_config")):
logger_config = LoggingConfig(**cfg)
else:
logger_config = None
config = StackRunConfig(**cast_image_name_to_string(replace_env_vars(config_contents)))
port = args.port or config.server.port
host = config.server.host or ["::", "0.0.0.0"]
# Set the config file in environment so create_app can find it
os.environ["LLAMA_STACK_CONFIG"] = str(config_file)
uvicorn_config = {
"factory": True,
"host": host,
"port": port,
"lifespan": "on",
"log_level": logger.getEffectiveLevel(),
"log_config": logger_config,
}
keyfile = config.server.tls_keyfile
certfile = config.server.tls_certfile
if keyfile and certfile:
uvicorn_config["ssl_keyfile"] = config.server.tls_keyfile
uvicorn_config["ssl_certfile"] = config.server.tls_certfile
if config.server.tls_cafile:
uvicorn_config["ssl_ca_certs"] = config.server.tls_cafile
uvicorn_config["ssl_cert_reqs"] = ssl.CERT_REQUIRED
logger.info(
f"HTTPS enabled with certificates:\n Key: {keyfile}\n Cert: {certfile}\n CA: {config.server.tls_cafile}"
)
else:
logger.info(f"HTTPS enabled with certificates:\n Key: {keyfile}\n Cert: {certfile}")
logger.info(f"Listening on {host}:{port}")
# We need to catch KeyboardInterrupt because uvicorn's signal handling
# re-raises SIGINT signals using signal.raise_signal(), which Python
# converts to KeyboardInterrupt. Without this catch, we'd get a confusing
# stack trace when using Ctrl+C or kill -2 (SIGINT).
# SIGTERM (kill -15) works fine without this because Python doesn't
# have a default handler for it.
#
# Another approach would be to ignore SIGINT entirely - let uvicorn handle it through its own
# signal handling but this is quite intrusive and not worth the effort.
try:
uvicorn.run("llama_stack.core.server.server:create_app", **uvicorn_config)
except (KeyboardInterrupt, SystemExit):
logger.info("Received interrupt signal, shutting down gracefully...")
def _start_ui_development_server(self, stack_server_port: int): def _start_ui_development_server(self, stack_server_port: int):
logger.info("Attempting to start UI development server...") logger.info("Attempting to start UI development server...")
# Check if npm is available # Check if npm is available

View file

@ -324,14 +324,14 @@ fi
RUN pip uninstall -y uv RUN pip uninstall -y uv
EOF EOF
# If a run config is provided, we use the --config flag # If a run config is provided, we use the llama stack CLI
if [[ -n "$run_config" ]]; then if [[ -n "$run_config" ]]; then
add_to_container << EOF add_to_container << EOF
ENTRYPOINT ["python", "-m", "llama_stack.core.server.server", "$RUN_CONFIG_PATH"] ENTRYPOINT ["llama", "stack", "run", "$RUN_CONFIG_PATH"]
EOF EOF
elif [[ "$distro_or_config" != *.yaml ]]; then elif [[ "$distro_or_config" != *.yaml ]]; then
add_to_container << EOF add_to_container << EOF
ENTRYPOINT ["python", "-m", "llama_stack.core.server.server", "$distro_or_config"] ENTRYPOINT ["llama", "stack", "run", "$distro_or_config"]
EOF EOF
fi fi

View file

@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@ -0,0 +1,306 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import os
import secrets
import time
from typing import Any
from openai import NOT_GIVEN
from pydantic import BaseModel, TypeAdapter
from llama_stack.apis.conversations.conversations import (
Conversation,
ConversationDeletedResource,
ConversationItem,
ConversationItemDeletedResource,
ConversationItemList,
Conversations,
Metadata,
)
from llama_stack.core.datatypes import AccessRule
from llama_stack.core.utils.config_dirs import DISTRIBS_BASE_DIR
from llama_stack.log import get_logger
from llama_stack.providers.utils.sqlstore.api import ColumnDefinition, ColumnType
from llama_stack.providers.utils.sqlstore.authorized_sqlstore import AuthorizedSqlStore
from llama_stack.providers.utils.sqlstore.sqlstore import (
SqliteSqlStoreConfig,
SqlStoreConfig,
sqlstore_impl,
)
logger = get_logger(name=__name__, category="openai::conversations")
class ConversationServiceConfig(BaseModel):
"""Configuration for the built-in conversation service.
:param conversations_store: SQL store configuration for conversations (defaults to SQLite)
:param policy: Access control rules
"""
conversations_store: SqlStoreConfig = SqliteSqlStoreConfig(
db_path=(DISTRIBS_BASE_DIR / "conversations.db").as_posix()
)
policy: list[AccessRule] = []
async def get_provider_impl(config: ConversationServiceConfig, deps: dict[Any, Any]):
"""Get the conversation service implementation."""
impl = ConversationServiceImpl(config, deps)
await impl.initialize()
return impl
class ConversationServiceImpl(Conversations):
"""Built-in conversation service implementation using AuthorizedSqlStore."""
def __init__(self, config: ConversationServiceConfig, deps: dict[Any, Any]):
self.config = config
self.deps = deps
self.policy = config.policy
base_sql_store = sqlstore_impl(config.conversations_store)
self.sql_store = AuthorizedSqlStore(base_sql_store, self.policy)
async def initialize(self) -> None:
"""Initialize the store and create tables."""
if isinstance(self.config.conversations_store, SqliteSqlStoreConfig):
os.makedirs(os.path.dirname(self.config.conversations_store.db_path), exist_ok=True)
await self.sql_store.create_table(
"openai_conversations",
{
"id": ColumnDefinition(type=ColumnType.STRING, primary_key=True),
"created_at": ColumnType.INTEGER,
"items": ColumnType.JSON,
"metadata": ColumnType.JSON,
},
)
await self.sql_store.create_table(
"conversation_items",
{
"id": ColumnDefinition(type=ColumnType.STRING, primary_key=True),
"conversation_id": ColumnType.STRING,
"created_at": ColumnType.INTEGER,
"item_data": ColumnType.JSON,
},
)
async def create_conversation(
self, items: list[ConversationItem] | None = None, metadata: Metadata | None = None
) -> Conversation:
"""Create a conversation."""
random_bytes = secrets.token_bytes(24)
conversation_id = f"conv_{random_bytes.hex()}"
created_at = int(time.time())
record_data = {
"id": conversation_id,
"created_at": created_at,
"items": [],
"metadata": metadata,
}
await self.sql_store.insert(
table="openai_conversations",
data=record_data,
)
if items:
item_records = []
for item in items:
item_dict = item.model_dump()
item_id = self._get_or_generate_item_id(item, item_dict)
item_record = {
"id": item_id,
"conversation_id": conversation_id,
"created_at": created_at,
"item_data": item_dict,
}
item_records.append(item_record)
await self.sql_store.insert(table="conversation_items", data=item_records)
conversation = Conversation(
id=conversation_id,
created_at=created_at,
metadata=metadata,
object="conversation",
)
logger.info(f"Created conversation {conversation_id}")
return conversation
async def get_conversation(self, conversation_id: str) -> Conversation:
"""Get a conversation with the given ID."""
record = await self.sql_store.fetch_one(table="openai_conversations", where={"id": conversation_id})
if record is None:
raise ValueError(f"Conversation {conversation_id} not found")
return Conversation(
id=record["id"], created_at=record["created_at"], metadata=record.get("metadata"), object="conversation"
)
async def update_conversation(self, conversation_id: str, metadata: Metadata) -> Conversation:
"""Update a conversation's metadata with the given ID"""
await self.sql_store.update(
table="openai_conversations", data={"metadata": metadata}, where={"id": conversation_id}
)
return await self.get_conversation(conversation_id)
async def openai_delete_conversation(self, conversation_id: str) -> ConversationDeletedResource:
"""Delete a conversation with the given ID."""
await self.sql_store.delete(table="openai_conversations", where={"id": conversation_id})
logger.info(f"Deleted conversation {conversation_id}")
return ConversationDeletedResource(id=conversation_id)
def _validate_conversation_id(self, conversation_id: str) -> None:
"""Validate conversation ID format."""
if not conversation_id.startswith("conv_"):
raise ValueError(
f"Invalid 'conversation_id': '{conversation_id}'. Expected an ID that begins with 'conv_'."
)
def _get_or_generate_item_id(self, item: ConversationItem, item_dict: dict) -> str:
"""Get existing item ID or generate one if missing."""
if item.id is None:
random_bytes = secrets.token_bytes(24)
if item.type == "message":
item_id = f"msg_{random_bytes.hex()}"
else:
item_id = f"item_{random_bytes.hex()}"
item_dict["id"] = item_id
return item_id
return item.id
async def _get_validated_conversation(self, conversation_id: str) -> Conversation:
"""Validate conversation ID and return the conversation if it exists."""
self._validate_conversation_id(conversation_id)
return await self.get_conversation(conversation_id)
async def add_items(self, conversation_id: str, items: list[ConversationItem]) -> ConversationItemList:
"""Create (add) items to a conversation."""
await self._get_validated_conversation(conversation_id)
created_items = []
created_at = int(time.time())
for item in items:
item_dict = item.model_dump()
item_id = self._get_or_generate_item_id(item, item_dict)
item_record = {
"id": item_id,
"conversation_id": conversation_id,
"created_at": created_at,
"item_data": item_dict,
}
# TODO: add upsert support to sql_store; the insert below fails if the ID already exists, and we then fall back to an update
try:
await self.sql_store.insert(table="conversation_items", data=item_record)
except Exception:
# If insert fails due to ID conflict, update existing record
await self.sql_store.update(
table="conversation_items",
data={"created_at": created_at, "item_data": item_dict},
where={"id": item_id},
)
created_items.append(item_dict)
logger.info(f"Created {len(created_items)} items in conversation {conversation_id}")
# Convert created items (dicts) to proper ConversationItem types
adapter: TypeAdapter[ConversationItem] = TypeAdapter(ConversationItem)
response_items: list[ConversationItem] = [adapter.validate_python(item_dict) for item_dict in created_items]
return ConversationItemList(
data=response_items,
first_id=created_items[0]["id"] if created_items else None,
last_id=created_items[-1]["id"] if created_items else None,
has_more=False,
)
async def retrieve(self, conversation_id: str, item_id: str) -> ConversationItem:
"""Retrieve a conversation item."""
if not conversation_id:
raise ValueError(f"Expected a non-empty value for `conversation_id` but received {conversation_id!r}")
if not item_id:
raise ValueError(f"Expected a non-empty value for `item_id` but received {item_id!r}")
# Get item from conversation_items table
record = await self.sql_store.fetch_one(
table="conversation_items", where={"id": item_id, "conversation_id": conversation_id}
)
if record is None:
raise ValueError(f"Item {item_id} not found in conversation {conversation_id}")
adapter: TypeAdapter[ConversationItem] = TypeAdapter(ConversationItem)
return adapter.validate_python(record["item_data"])
async def list(self, conversation_id: str, after=NOT_GIVEN, include=NOT_GIVEN, limit=NOT_GIVEN, order=NOT_GIVEN):
"""List items in the conversation."""
result = await self.sql_store.fetch_all(table="conversation_items", where={"conversation_id": conversation_id})
records = result.data
if order != NOT_GIVEN and order == "asc":
records.sort(key=lambda x: x["created_at"])
else:
records.sort(key=lambda x: x["created_at"], reverse=True)
actual_limit = 20
if limit != NOT_GIVEN and isinstance(limit, int):
actual_limit = limit
records = records[:actual_limit]
items = [record["item_data"] for record in records]
adapter: TypeAdapter[ConversationItem] = TypeAdapter(ConversationItem)
response_items: list[ConversationItem] = [adapter.validate_python(item) for item in items]
first_id = response_items[0].id if response_items else None
last_id = response_items[-1].id if response_items else None
return ConversationItemList(
data=response_items,
first_id=first_id,
last_id=last_id,
has_more=False,
)
async def openai_delete_conversation_item(
self, conversation_id: str, item_id: str
) -> ConversationItemDeletedResource:
"""Delete a conversation item."""
if not conversation_id:
raise ValueError(f"Expected a non-empty value for `conversation_id` but received {conversation_id!r}")
if not item_id:
raise ValueError(f"Expected a non-empty value for `item_id` but received {item_id!r}")
_ = await self._get_validated_conversation(conversation_id)
record = await self.sql_store.fetch_one(
table="conversation_items", where={"id": item_id, "conversation_id": conversation_id}
)
if record is None:
raise ValueError(f"Item {item_id} not found in conversation {conversation_id}")
await self.sql_store.delete(
table="conversation_items", where={"id": item_id, "conversation_id": conversation_id}
)
logger.info(f"Deleted item {item_id} from conversation {conversation_id}")
return ConversationItemDeletedResource(id=item_id)
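A small end-to-end sketch of the new service defined above; the SQLite path and metadata values are placeholders, and the calls mirror the methods in this file.

import asyncio

from llama_stack.core.conversations.conversations import (
    ConversationServiceConfig,
    ConversationServiceImpl,
)
from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig


async def demo() -> None:
    config = ConversationServiceConfig(
        conversations_store=SqliteSqlStoreConfig(db_path="/tmp/conversations.db"),  # placeholder path
        policy=[],
    )
    service = ConversationServiceImpl(config, deps={})
    await service.initialize()

    conv = await service.create_conversation(metadata={"topic": "demo"})
    fetched = await service.get_conversation(conv.id)
    print(fetched.id, fetched.created_at)

    await service.openai_delete_conversation(conv.id)


asyncio.run(demo())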

View file

@ -475,6 +475,13 @@ InferenceStoreConfig (with queue tuning parameters) or a SqlStoreConfig (depreca
If not specified, a default SQLite store will be used.""", If not specified, a default SQLite store will be used.""",
) )
conversations_store: SqlStoreConfig | None = Field(
default=None,
description="""
Configuration for the persistence store used by the conversations API.
If not specified, a default SQLite store will be used.""",
)
# registry of "resources" in the distribution # registry of "resources" in the distribution
models: list[ModelInput] = Field(default_factory=list) models: list[ModelInput] = Field(default_factory=list)
shields: list[ShieldInput] = Field(default_factory=list) shields: list[ShieldInput] = Field(default_factory=list)

View file

@ -25,7 +25,7 @@ from llama_stack.providers.datatypes import (
logger = get_logger(name=__name__, category="core") logger = get_logger(name=__name__, category="core")
INTERNAL_APIS = {Api.inspect, Api.providers, Api.prompts} INTERNAL_APIS = {Api.inspect, Api.providers, Api.prompts, Api.conversations}
def stack_apis() -> list[Api]: def stack_apis() -> list[Api]:
@ -243,6 +243,7 @@ def get_external_providers_from_module(
spec = module.get_provider_spec() spec = module.get_provider_spec()
else: else:
# pass in a partially filled out provider spec to satisfy the registry -- knowing we will be overwriting it later upon build and run # pass in a partially filled out provider spec to satisfy the registry -- knowing we will be overwriting it later upon build and run
# when building, we cannot import this module because it has not been installed yet.
spec = ProviderSpec( spec = ProviderSpec(
api=Api(provider_api), api=Api(provider_api),
provider_type=provider.provider_type, provider_type=provider.provider_type,
@ -251,9 +252,20 @@ def get_external_providers_from_module(
config_class="", config_class="",
) )
provider_type = provider.provider_type provider_type = provider.provider_type
# in the case we are building we CANNOT import this module of course because it has not been installed. if isinstance(spec, list):
# return a partially filled out spec that the build script will populate. # optionally allow people to pass inline and remote provider specs as a returned list.
registry[Api(provider_api)][provider_type] = spec # with the old method, users could pass in directories of specs using overlapping code
# we want to ensure we preserve that flexibility in this method.
logger.info(
f"Detected a list of external provider specs from {provider.module} adding all to the registry"
)
for provider_spec in spec:
if provider_spec.provider_type != provider.provider_type:
continue
logger.info(f"Adding {provider.provider_type} to registry")
registry[Api(provider_api)][provider.provider_type] = provider_spec
else:
registry[Api(provider_api)][provider_type] = spec
except ModuleNotFoundError as exc: except ModuleNotFoundError as exc:
raise ValueError( raise ValueError(
"get_provider_spec not found. If specifying an external provider via `module` in the Provider spec, the Provider must have the `provider.get_provider_spec` module available" "get_provider_spec not found. If specifying an external provider via `module` in the Provider spec, the Provider must have the `provider.get_provider_spec` module available"

View file

@ -374,6 +374,10 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
body = options.params or {} body = options.params or {}
body |= options.json_data or {} body |= options.json_data or {}
# Merge extra_json parameters (extra_body from SDK is converted to extra_json)
if hasattr(options, "extra_json") and options.extra_json:
body |= options.extra_json
matched_func, path_params, route_path, webmethod = find_matching_route(options.method, path, self.route_impls) matched_func, path_params, route_path, webmethod = find_matching_route(options.method, path, self.route_impls)
body |= path_params body |= path_params
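The practical effect, sketched from the caller's side: anything passed as extra_body through the SDK (which arrives here as options.extra_json) now reaches the route handler. The distro name, model id, and shields value are illustrative assumptions.

import asyncio

from llama_stack.core.library_client import AsyncLlamaStackAsLibraryClient


async def demo() -> None:
    client = AsyncLlamaStackAsLibraryClient("starter")  # assumed distro name
    await client.initialize()

    response = await client.responses.create(
        model="llama3.2:3b",                      # assumed model id
        input="Summarize our refund policy.",
        extra_body={"shields": ["llama-guard"]},  # merged into the request body via extra_json
    )
    print(response.id)


asyncio.run(demo())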

View file

@ -10,6 +10,7 @@ from typing import Any
from llama_stack.apis.agents import Agents from llama_stack.apis.agents import Agents
from llama_stack.apis.batches import Batches from llama_stack.apis.batches import Batches
from llama_stack.apis.benchmarks import Benchmarks from llama_stack.apis.benchmarks import Benchmarks
from llama_stack.apis.conversations import Conversations
from llama_stack.apis.datasetio import DatasetIO from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.datasets import Datasets from llama_stack.apis.datasets import Datasets
from llama_stack.apis.datatypes import ExternalApiSpec from llama_stack.apis.datatypes import ExternalApiSpec
@ -96,6 +97,7 @@ def api_protocol_map(external_apis: dict[Api, ExternalApiSpec] | None = None) ->
Api.tool_runtime: ToolRuntime, Api.tool_runtime: ToolRuntime,
Api.files: Files, Api.files: Files,
Api.prompts: Prompts, Api.prompts: Prompts,
Api.conversations: Conversations,
} }
if external_apis: if external_apis:

View file

@ -19,7 +19,6 @@ from llama_stack.apis.inference import (
CompletionMessage, CompletionMessage,
Inference, Inference,
ListOpenAIChatCompletionResponse, ListOpenAIChatCompletionResponse,
LogProbConfig,
Message, Message,
OpenAIAssistantMessageParam, OpenAIAssistantMessageParam,
OpenAIChatCompletion, OpenAIChatCompletion,
@ -34,12 +33,7 @@ from llama_stack.apis.inference import (
OpenAIMessageParam, OpenAIMessageParam,
OpenAIResponseFormatParam, OpenAIResponseFormatParam,
Order, Order,
ResponseFormat,
SamplingParams,
StopReason, StopReason,
ToolChoice,
ToolConfig,
ToolDefinition,
ToolPromptFormat, ToolPromptFormat,
) )
from llama_stack.apis.models import Model, ModelType from llama_stack.apis.models import Model, ModelType
@ -193,102 +187,6 @@ class InferenceRouter(Inference):
raise ModelTypeError(model_id, model.model_type, expected_model_type) raise ModelTypeError(model_id, model.model_type, expected_model_type)
return model return model
async def chat_completion(
self,
model_id: str,
messages: list[Message],
sampling_params: SamplingParams | None = None,
response_format: ResponseFormat | None = None,
tools: list[ToolDefinition] | None = None,
tool_choice: ToolChoice | None = None,
tool_prompt_format: ToolPromptFormat | None = None,
stream: bool | None = False,
logprobs: LogProbConfig | None = None,
tool_config: ToolConfig | None = None,
) -> ChatCompletionResponse | AsyncIterator[ChatCompletionResponseStreamChunk]:
logger.debug(
f"InferenceRouter.chat_completion: {model_id=}, {stream=}, {messages=}, {tools=}, {tool_config=}, {response_format=}",
)
if sampling_params is None:
sampling_params = SamplingParams()
model = await self._get_model(model_id, ModelType.llm)
if tool_config:
if tool_choice and tool_choice != tool_config.tool_choice:
raise ValueError("tool_choice and tool_config.tool_choice must match")
if (
tool_prompt_format
and tool_prompt_format != tool_config.tool_prompt_format
):
raise ValueError(
"tool_prompt_format and tool_config.tool_prompt_format must match"
)
else:
params = {}
if tool_choice:
params["tool_choice"] = tool_choice
if tool_prompt_format:
params["tool_prompt_format"] = tool_prompt_format
tool_config = ToolConfig(**params)
tools = tools or []
if tool_config.tool_choice == ToolChoice.none:
tools = []
elif tool_config.tool_choice == ToolChoice.auto:
pass
elif tool_config.tool_choice == ToolChoice.required:
pass
else:
# verify tool_choice is one of the tools
tool_names = [
t.tool_name if isinstance(t.tool_name, str) else t.tool_name.value
for t in tools
]
if tool_config.tool_choice not in tool_names:
raise ValueError(
f"Tool choice {tool_config.tool_choice} is not one of the tools: {tool_names}"
)
params = dict(
model_id=model_id,
messages=messages,
sampling_params=sampling_params,
tools=tools,
tool_choice=tool_choice,
tool_prompt_format=tool_prompt_format,
response_format=response_format,
stream=stream,
logprobs=logprobs,
tool_config=tool_config,
)
provider = await self.routing_table.get_provider_impl(model_id)
prompt_tokens = await self._count_tokens(
messages, tool_config.tool_prompt_format
)
if stream:
response_stream = await provider.chat_completion(**params)
return self.stream_tokens_and_compute_metrics(
response=response_stream,
prompt_tokens=prompt_tokens,
model=model,
tool_prompt_format=tool_config.tool_prompt_format,
)
response = await provider.chat_completion(**params)
metrics = await self.count_tokens_and_compute_metrics(
response=response,
prompt_tokens=prompt_tokens,
model=model,
tool_prompt_format=tool_config.tool_prompt_format,
)
# these metrics will show up in the client response.
response.metrics = (
metrics
if not hasattr(response, "metrics") or response.metrics is None
else response.metrics + metrics
)
return response
async def openai_completion( async def openai_completion(
self, self,
model: str, model: str,

View file

@ -4,7 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
import argparse
import asyncio import asyncio
import concurrent.futures import concurrent.futures
import functools import functools
@ -12,7 +11,6 @@ import inspect
import json import json
import logging # allow-direct-logging import logging # allow-direct-logging
import os import os
import ssl
import sys import sys
import traceback import traceback
import warnings import warnings
@ -35,7 +33,6 @@ from pydantic import BaseModel, ValidationError
from llama_stack.apis.common.errors import ConflictError, ResourceNotFoundError from llama_stack.apis.common.errors import ConflictError, ResourceNotFoundError
from llama_stack.apis.common.responses import PaginatedResponse from llama_stack.apis.common.responses import PaginatedResponse
from llama_stack.cli.utils import add_config_distro_args, get_config_from_args
from llama_stack.core.access_control.access_control import AccessDeniedError from llama_stack.core.access_control.access_control import AccessDeniedError
from llama_stack.core.datatypes import ( from llama_stack.core.datatypes import (
AuthenticationRequiredError, AuthenticationRequiredError,
@ -55,7 +52,6 @@ from llama_stack.core.stack import (
Stack, Stack,
cast_image_name_to_string, cast_image_name_to_string,
replace_env_vars, replace_env_vars,
validate_env_pair,
) )
from llama_stack.core.utils.config import redact_sensitive_fields from llama_stack.core.utils.config import redact_sensitive_fields
from llama_stack.core.utils.config_resolution import Mode, resolve_config_or_distro from llama_stack.core.utils.config_resolution import Mode, resolve_config_or_distro
@ -333,23 +329,18 @@ class ClientVersionMiddleware:
return await self.app(scope, receive, send) return await self.app(scope, receive, send)
def create_app( def create_app() -> StackApp:
config_file: str | None = None,
env_vars: list[str] | None = None,
) -> StackApp:
"""Create and configure the FastAPI application. """Create and configure the FastAPI application.
Args: This factory function reads configuration from environment variables:
config_file: Path to config file. If None, uses LLAMA_STACK_CONFIG env var or default resolution. - LLAMA_STACK_CONFIG: Path to config file (required)
env_vars: List of environment variables in KEY=value format.
disable_version_check: Whether to disable version checking. If None, uses LLAMA_STACK_DISABLE_VERSION_CHECK env var.
Returns: Returns:
Configured StackApp instance. Configured StackApp instance.
""" """
config_file = config_file or os.getenv("LLAMA_STACK_CONFIG") config_file = os.getenv("LLAMA_STACK_CONFIG")
if config_file is None: if config_file is None:
raise ValueError("No config file provided and LLAMA_STACK_CONFIG env var is not set") raise ValueError("LLAMA_STACK_CONFIG environment variable is required")
config_file = resolve_config_or_distro(config_file, Mode.RUN) config_file = resolve_config_or_distro(config_file, Mode.RUN)
@ -361,16 +352,6 @@ def create_app(
logger_config = LoggingConfig(**cfg) logger_config = LoggingConfig(**cfg)
logger = get_logger(name=__name__, category="core::server", config=logger_config) logger = get_logger(name=__name__, category="core::server", config=logger_config)
if env_vars:
for env_pair in env_vars:
try:
key, value = validate_env_pair(env_pair)
logger.info(f"Setting environment variable {key} => {value}")
os.environ[key] = value
except ValueError as e:
logger.error(f"Error: {str(e)}")
raise ValueError(f"Invalid environment variable format: {env_pair}") from e
config = replace_env_vars(config_contents) config = replace_env_vars(config_contents)
config = StackRunConfig(**cast_image_name_to_string(config)) config = StackRunConfig(**cast_image_name_to_string(config))
@ -451,6 +432,7 @@ def create_app(
apis_to_serve.add("inspect") apis_to_serve.add("inspect")
apis_to_serve.add("providers") apis_to_serve.add("providers")
apis_to_serve.add("prompts") apis_to_serve.add("prompts")
apis_to_serve.add("conversations")
for api_str in apis_to_serve: for api_str in apis_to_serve:
api = Api(api_str) api = Api(api_str)
@ -493,101 +475,6 @@ def create_app(
return app return app
def main(args: argparse.Namespace | None = None):
"""Start the LlamaStack server."""
parser = argparse.ArgumentParser(description="Start the LlamaStack server.")
add_config_distro_args(parser)
parser.add_argument(
"--port",
type=int,
default=int(os.getenv("LLAMA_STACK_PORT", 8321)),
help="Port to listen on",
)
parser.add_argument(
"--env",
action="append",
help="Environment variables in KEY=value format. Can be specified multiple times.",
)
# Determine whether the server args are being passed by the "run" command, if this is the case
# the args will be passed as a Namespace object to the main function, otherwise they will be
# parsed from the command line
if args is None:
args = parser.parse_args()
config_or_distro = get_config_from_args(args)
try:
app = create_app(
config_file=config_or_distro,
env_vars=args.env,
)
except Exception as e:
logger.error(f"Error creating app: {str(e)}")
sys.exit(1)
config_file = resolve_config_or_distro(config_or_distro, Mode.RUN)
with open(config_file) as fp:
config_contents = yaml.safe_load(fp)
if isinstance(config_contents, dict) and (cfg := config_contents.get("logging_config")):
logger_config = LoggingConfig(**cfg)
else:
logger_config = None
config = StackRunConfig(**cast_image_name_to_string(replace_env_vars(config_contents)))
import uvicorn
# Configure SSL if certificates are provided
port = args.port or config.server.port
ssl_config = None
keyfile = config.server.tls_keyfile
certfile = config.server.tls_certfile
if keyfile and certfile:
ssl_config = {
"ssl_keyfile": keyfile,
"ssl_certfile": certfile,
}
if config.server.tls_cafile:
ssl_config["ssl_ca_certs"] = config.server.tls_cafile
ssl_config["ssl_cert_reqs"] = ssl.CERT_REQUIRED
logger.info(
f"HTTPS enabled with certificates:\n Key: {keyfile}\n Cert: {certfile}\n CA: {config.server.tls_cafile}"
)
else:
logger.info(f"HTTPS enabled with certificates:\n Key: {keyfile}\n Cert: {certfile}")
listen_host = config.server.host or ["::", "0.0.0.0"]
logger.info(f"Listening on {listen_host}:{port}")
uvicorn_config = {
"app": app,
"host": listen_host,
"port": port,
"lifespan": "on",
"log_level": logger.getEffectiveLevel(),
"log_config": logger_config,
}
if ssl_config:
uvicorn_config.update(ssl_config)
# We need to catch KeyboardInterrupt because uvicorn's signal handling
# re-raises SIGINT signals using signal.raise_signal(), which Python
# converts to KeyboardInterrupt. Without this catch, we'd get a confusing
# stack trace when using Ctrl+C or kill -2 (SIGINT).
# SIGTERM (kill -15) works fine without this because Python doesn't
# have a default handler for it.
#
# Another approach would be to ignore SIGINT entirely - let uvicorn handle it through its own
# signal handling but this is quite intrusive and not worth the effort.
try:
asyncio.run(uvicorn.Server(uvicorn.Config(**uvicorn_config)).serve())
except (KeyboardInterrupt, SystemExit):
logger.info("Received interrupt signal, shutting down gracefully...")
def _log_run_config(run_config: StackRunConfig): def _log_run_config(run_config: StackRunConfig):
"""Logs the run config with redacted fields and disabled providers removed.""" """Logs the run config with redacted fields and disabled providers removed."""
logger.info("Run configuration:") logger.info("Run configuration:")
@ -614,7 +501,3 @@ def remove_disabled_providers(obj):
return [item for item in (remove_disabled_providers(i) for i in obj) if item is not None] return [item for item in (remove_disabled_providers(i) for i in obj) if item is not None]
else: else:
return obj return obj
if __name__ == "__main__":
main()

View file

@ -15,6 +15,7 @@ import yaml
from llama_stack.apis.agents import Agents from llama_stack.apis.agents import Agents
from llama_stack.apis.benchmarks import Benchmarks from llama_stack.apis.benchmarks import Benchmarks
from llama_stack.apis.conversations import Conversations
from llama_stack.apis.datasetio import DatasetIO from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.datasets import Datasets from llama_stack.apis.datasets import Datasets
from llama_stack.apis.eval import Eval from llama_stack.apis.eval import Eval
@ -34,6 +35,7 @@ from llama_stack.apis.telemetry import Telemetry
from llama_stack.apis.tools import RAGToolRuntime, ToolGroups, ToolRuntime from llama_stack.apis.tools import RAGToolRuntime, ToolGroups, ToolRuntime
from llama_stack.apis.vector_dbs import VectorDBs from llama_stack.apis.vector_dbs import VectorDBs
from llama_stack.apis.vector_io import VectorIO from llama_stack.apis.vector_io import VectorIO
from llama_stack.core.conversations.conversations import ConversationServiceConfig, ConversationServiceImpl
from llama_stack.core.datatypes import Provider, StackRunConfig from llama_stack.core.datatypes import Provider, StackRunConfig
from llama_stack.core.distribution import get_provider_registry from llama_stack.core.distribution import get_provider_registry
from llama_stack.core.inspect import DistributionInspectConfig, DistributionInspectImpl from llama_stack.core.inspect import DistributionInspectConfig, DistributionInspectImpl
@ -73,6 +75,7 @@ class LlamaStack(
RAGToolRuntime, RAGToolRuntime,
Files, Files,
Prompts, Prompts,
Conversations,
): ):
pass pass
@ -312,6 +315,12 @@ def add_internal_implementations(impls: dict[Api, Any], run_config: StackRunConf
) )
impls[Api.prompts] = prompts_impl impls[Api.prompts] = prompts_impl
conversations_impl = ConversationServiceImpl(
ConversationServiceConfig(run_config=run_config),
deps=impls,
)
impls[Api.conversations] = conversations_impl
class Stack: class Stack:
def __init__(self, run_config: StackRunConfig, provider_registry: ProviderRegistry | None = None): def __init__(self, run_config: StackRunConfig, provider_registry: ProviderRegistry | None = None):
@ -342,6 +351,8 @@ class Stack:
if Api.prompts in impls: if Api.prompts in impls:
await impls[Api.prompts].initialize() await impls[Api.prompts].initialize()
if Api.conversations in impls:
await impls[Api.conversations].initialize()
await register_resources(self.run_config, impls) await register_resources(self.run_config, impls)

View file

@ -116,7 +116,7 @@ if [[ "$env_type" == "venv" ]]; then
yaml_config_arg="" yaml_config_arg=""
fi fi
$PYTHON_BINARY -m llama_stack.core.server.server \ llama stack run \
$yaml_config_arg \ $yaml_config_arg \
--port "$port" \ --port "$port" \
$env_vars \ $env_vars \

View file

@ -128,7 +128,7 @@ def strip_rich_markup(text):
class CustomRichHandler(RichHandler): class CustomRichHandler(RichHandler):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
kwargs["console"] = Console(width=150) kwargs["console"] = Console()
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
def emit(self, record): def emit(self, record):

View file

@ -9,7 +9,7 @@ from pathlib import Path
from llama_stack.log import get_logger from llama_stack.log import get_logger
logger = get_logger(__name__, "tokenizer_utils") logger = get_logger(__name__, "models")
def load_bpe_file(model_path: Path) -> dict[bytes, int]: def load_bpe_file(model_path: Path) -> dict[bytes, int]:

View file

@ -329,6 +329,7 @@ class MetaReferenceAgentsImpl(Agents):
tools: list[OpenAIResponseInputTool] | None = None, tools: list[OpenAIResponseInputTool] | None = None,
include: list[str] | None = None, include: list[str] | None = None,
max_infer_iters: int | None = 10, max_infer_iters: int | None = 10,
shields: list | None = None,
) -> OpenAIResponseObject: ) -> OpenAIResponseObject:
return await self.openai_responses_impl.create_openai_response( return await self.openai_responses_impl.create_openai_response(
input, input,
@ -342,6 +343,7 @@ class MetaReferenceAgentsImpl(Agents):
tools, tools,
include, include,
max_infer_iters, max_infer_iters,
shields,
) )
async def list_openai_responses( async def list_openai_responses(

View file

@ -8,7 +8,7 @@ import time
import uuid import uuid
from collections.abc import AsyncIterator from collections.abc import AsyncIterator
from pydantic import BaseModel from pydantic import BaseModel, TypeAdapter
from llama_stack.apis.agents import Order from llama_stack.apis.agents import Order
from llama_stack.apis.agents.openai_responses import ( from llama_stack.apis.agents.openai_responses import (
@ -26,12 +26,16 @@ from llama_stack.apis.agents.openai_responses import (
) )
from llama_stack.apis.inference import ( from llama_stack.apis.inference import (
Inference, Inference,
OpenAIMessageParam,
OpenAISystemMessageParam, OpenAISystemMessageParam,
) )
from llama_stack.apis.tools import ToolGroups, ToolRuntime from llama_stack.apis.tools import ToolGroups, ToolRuntime
from llama_stack.apis.vector_io import VectorIO from llama_stack.apis.vector_io import VectorIO
from llama_stack.log import get_logger from llama_stack.log import get_logger
from llama_stack.providers.utils.responses.responses_store import ResponsesStore from llama_stack.providers.utils.responses.responses_store import (
ResponsesStore,
_OpenAIResponseObjectWithInputAndMessages,
)
from .streaming import StreamingResponseOrchestrator from .streaming import StreamingResponseOrchestrator
from .tool_executor import ToolExecutor from .tool_executor import ToolExecutor
@ -72,26 +76,48 @@ class OpenAIResponsesImpl:
async def _prepend_previous_response( async def _prepend_previous_response(
self, self,
input: str | list[OpenAIResponseInput], input: str | list[OpenAIResponseInput],
previous_response_id: str | None = None, previous_response: _OpenAIResponseObjectWithInputAndMessages,
): ):
new_input_items = previous_response.input.copy()
new_input_items.extend(previous_response.output)
if isinstance(input, str):
new_input_items.append(OpenAIResponseMessage(content=input, role="user"))
else:
new_input_items.extend(input)
return new_input_items
async def _process_input_with_previous_response(
self,
input: str | list[OpenAIResponseInput],
previous_response_id: str | None,
) -> tuple[str | list[OpenAIResponseInput], list[OpenAIMessageParam]]:
"""Process input with optional previous response context.
Returns:
tuple: (all_input for storage, messages for chat completion)
"""
if previous_response_id: if previous_response_id:
previous_response_with_input = await self.responses_store.get_response_object(previous_response_id) previous_response: _OpenAIResponseObjectWithInputAndMessages = (
await self.responses_store.get_response_object(previous_response_id)
)
all_input = await self._prepend_previous_response(input, previous_response)
# previous response input items if previous_response.messages:
new_input_items = previous_response_with_input.input # Use stored messages directly and convert only new input
message_adapter = TypeAdapter(list[OpenAIMessageParam])
# previous response output items messages = message_adapter.validate_python(previous_response.messages)
new_input_items.extend(previous_response_with_input.output) new_messages = await convert_response_input_to_chat_messages(input)
messages.extend(new_messages)
# new input items from the current request
if isinstance(input, str):
new_input_items.append(OpenAIResponseMessage(content=input, role="user"))
else: else:
new_input_items.extend(input) # Backward compatibility: reconstruct from inputs
messages = await convert_response_input_to_chat_messages(all_input)
else:
all_input = input
messages = await convert_response_input_to_chat_messages(input)
input = new_input_items return all_input, messages
return input
async def _prepend_instructions(self, messages, instructions): async def _prepend_instructions(self, messages, instructions):
if instructions: if instructions:
@ -102,7 +128,7 @@ class OpenAIResponsesImpl:
response_id: str, response_id: str,
) -> OpenAIResponseObject: ) -> OpenAIResponseObject:
response_with_input = await self.responses_store.get_response_object(response_id) response_with_input = await self.responses_store.get_response_object(response_id)
return OpenAIResponseObject(**{k: v for k, v in response_with_input.model_dump().items() if k != "input"}) return response_with_input.to_response_object()
async def list_openai_responses( async def list_openai_responses(
self, self,
@ -138,6 +164,7 @@ class OpenAIResponsesImpl:
self, self,
response: OpenAIResponseObject, response: OpenAIResponseObject,
input: str | list[OpenAIResponseInput], input: str | list[OpenAIResponseInput],
messages: list[OpenAIMessageParam],
) -> None: ) -> None:
new_input_id = f"msg_{uuid.uuid4()}" new_input_id = f"msg_{uuid.uuid4()}"
if isinstance(input, str): if isinstance(input, str):
@ -165,6 +192,7 @@ class OpenAIResponsesImpl:
await self.responses_store.store_response_object( await self.responses_store.store_response_object(
response_object=response, response_object=response,
input=input_items_data, input=input_items_data,
messages=messages,
) )
async def create_openai_response( async def create_openai_response(
@ -180,10 +208,15 @@ class OpenAIResponsesImpl:
tools: list[OpenAIResponseInputTool] | None = None, tools: list[OpenAIResponseInputTool] | None = None,
include: list[str] | None = None, include: list[str] | None = None,
max_infer_iters: int | None = 10, max_infer_iters: int | None = 10,
shields: list | None = None,
): ):
stream = bool(stream) stream = bool(stream)
text = OpenAIResponseText(format=OpenAIResponseTextFormat(type="text")) if text is None else text text = OpenAIResponseText(format=OpenAIResponseTextFormat(type="text")) if text is None else text
# Shields parameter received via extra_body - not yet implemented
if shields is not None:
raise NotImplementedError("Shields parameter is not yet implemented in the meta-reference provider")
stream_gen = self._create_streaming_response( stream_gen = self._create_streaming_response(
input=input, input=input,
model=model, model=model,
@ -224,8 +257,7 @@ class OpenAIResponsesImpl:
max_infer_iters: int | None = 10, max_infer_iters: int | None = 10,
) -> AsyncIterator[OpenAIResponseObjectStream]: ) -> AsyncIterator[OpenAIResponseObjectStream]:
# Input preprocessing # Input preprocessing
input = await self._prepend_previous_response(input, previous_response_id) all_input, messages = await self._process_input_with_previous_response(input, previous_response_id)
messages = await convert_response_input_to_chat_messages(input)
await self._prepend_instructions(messages, instructions) await self._prepend_instructions(messages, instructions)
# Structured outputs # Structured outputs
@ -265,7 +297,8 @@ class OpenAIResponsesImpl:
if store and final_response: if store and final_response:
await self._store_response( await self._store_response(
response=final_response, response=final_response,
input=input, input=all_input,
messages=orchestrator.final_messages,
) )
async def delete_openai_response(self, response_id: str) -> OpenAIDeleteResponseObject: async def delete_openai_response(self, response_id: str) -> OpenAIDeleteResponseObject:
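From the API consumer's perspective, the behavior this refactor preserves is the usual response chaining; a short sketch against a running stack, where the base URL and model id are assumptions:

from openai import OpenAI

client = OpenAI(base_url="http://localhost:8321/v1", api_key="none")  # assumed local stack

first = client.responses.create(
    model="llama3.2:3b",                          # assumed model id
    input="Name three graph algorithms.",
)
follow_up = client.responses.create(
    model="llama3.2:3b",
    input="Explain the second one in one sentence.",
    previous_response_id=first.id,                # server replays stored context
)
print(follow_up.output_text)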

View file

@ -43,6 +43,7 @@ from llama_stack.apis.inference import (
OpenAIChatCompletion, OpenAIChatCompletion,
OpenAIChatCompletionToolCall, OpenAIChatCompletionToolCall,
OpenAIChoice, OpenAIChoice,
OpenAIMessageParam,
) )
from llama_stack.log import get_logger from llama_stack.log import get_logger
@ -94,6 +95,8 @@ class StreamingResponseOrchestrator:
self.sequence_number = 0 self.sequence_number = 0
# Store MCP tool mapping that gets built during tool processing # Store MCP tool mapping that gets built during tool processing
self.mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP] = {} self.mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP] = {}
# Track final messages after all tool executions
self.final_messages: list[OpenAIMessageParam] = []
async def create_response(self) -> AsyncIterator[OpenAIResponseObjectStream]: async def create_response(self) -> AsyncIterator[OpenAIResponseObjectStream]:
# Initialize output messages # Initialize output messages
@ -183,6 +186,8 @@ class StreamingResponseOrchestrator:
messages = next_turn_messages messages = next_turn_messages
self.final_messages = messages.copy() + [current_response.choices[0].message]
# Create final response # Create final response
final_response = OpenAIResponseObject( final_response = OpenAIResponseObject(
created_at=self.created_at, created_at=self.created_at,

View file

@ -5,37 +5,17 @@
# the root directory of this source tree. # the root directory of this source tree.
import asyncio import asyncio
import os from collections.abc import AsyncIterator
import sys from typing import Any
from collections.abc import AsyncGenerator
from pydantic import BaseModel
from termcolor import cprint
from llama_stack.apis.common.content_types import (
TextDelta,
ToolCallDelta,
ToolCallParseStatus,
)
from llama_stack.apis.inference import ( from llama_stack.apis.inference import (
ChatCompletionRequest,
ChatCompletionResponse,
ChatCompletionResponseEvent,
ChatCompletionResponseEventType,
ChatCompletionResponseStreamChunk,
CompletionMessage,
InferenceProvider, InferenceProvider,
LogProbConfig, )
Message, from llama_stack.apis.inference.inference import (
ResponseFormat, OpenAIChatCompletion,
SamplingParams, OpenAIChatCompletionChunk,
StopReason, OpenAIMessageParam,
TokenLogProbs, OpenAIResponseFormatParam,
ToolChoice,
ToolConfig,
ToolDefinition,
ToolPromptFormat,
UserMessage,
) )
from llama_stack.apis.models import Model, ModelType from llama_stack.apis.models import Model, ModelType
from llama_stack.log import get_logger from llama_stack.log import get_logger
@ -53,13 +33,6 @@ from llama_stack.providers.utils.inference.model_registry import (
ModelRegistryHelper, ModelRegistryHelper,
build_hf_repo_model_entry, build_hf_repo_model_entry,
) )
from llama_stack.providers.utils.inference.openai_compat import (
OpenAIChatCompletionToLlamaStackMixin,
)
from llama_stack.providers.utils.inference.prompt_adapter import (
chat_completion_request_to_messages,
convert_request_to_raw,
)
from .config import MetaReferenceInferenceConfig from .config import MetaReferenceInferenceConfig
from .generators import LlamaGenerator from .generators import LlamaGenerator
@ -76,7 +49,6 @@ def llama_builder_fn(config: MetaReferenceInferenceConfig, model_id: str, llama_
class MetaReferenceInferenceImpl( class MetaReferenceInferenceImpl(
OpenAIChatCompletionToLlamaStackMixin,
SentenceTransformerEmbeddingMixin, SentenceTransformerEmbeddingMixin,
InferenceProvider, InferenceProvider,
ModelsProtocolPrivate, ModelsProtocolPrivate,
@ -161,10 +133,10 @@ class MetaReferenceInferenceImpl(
self.llama_model = llama_model self.llama_model = llama_model
log.info("Warming up...") log.info("Warming up...")
await self.chat_completion( await self.openai_chat_completion(
model_id=model_id, model=model_id,
messages=[UserMessage(content="Hi how are you?")], messages=[{"role": "user", "content": "Hi how are you?"}],
sampling_params=SamplingParams(max_tokens=20), max_tokens=20,
) )
log.info("Warmed up!") log.info("Warmed up!")
@ -176,242 +148,30 @@ class MetaReferenceInferenceImpl(
elif request.model != self.model_id: elif request.model != self.model_id:
raise RuntimeError(f"Model mismatch: request model: {request.model} != loaded model: {self.model_id}") raise RuntimeError(f"Model mismatch: request model: {request.model} != loaded model: {self.model_id}")
async def chat_completion( async def openai_chat_completion(
self, self,
model_id: str, model: str,
messages: list[Message], messages: list[OpenAIMessageParam],
sampling_params: SamplingParams | None = None, frequency_penalty: float | None = None,
response_format: ResponseFormat | None = None, function_call: str | dict[str, Any] | None = None,
tools: list[ToolDefinition] | None = None, functions: list[dict[str, Any]] | None = None,
tool_choice: ToolChoice | None = ToolChoice.auto, logit_bias: dict[str, float] | None = None,
tool_prompt_format: ToolPromptFormat | None = None, logprobs: bool | None = None,
stream: bool | None = False, max_completion_tokens: int | None = None,
logprobs: LogProbConfig | None = None, max_tokens: int | None = None,
tool_config: ToolConfig | None = None, n: int | None = None,
) -> AsyncGenerator: parallel_tool_calls: bool | None = None,
if sampling_params is None: presence_penalty: float | None = None,
sampling_params = SamplingParams() response_format: OpenAIResponseFormatParam | None = None,
if logprobs: seed: int | None = None,
assert logprobs.top_k == 1, f"Unexpected top_k={logprobs.top_k}" stop: str | list[str] | None = None,
stream: bool | None = None,
# wrapper request to make it easier to pass around (internal only, not exposed to API) stream_options: dict[str, Any] | None = None,
request = ChatCompletionRequest( temperature: float | None = None,
model=model_id, tool_choice: str | dict[str, Any] | None = None,
messages=messages, tools: list[dict[str, Any]] | None = None,
sampling_params=sampling_params, top_logprobs: int | None = None,
tools=tools or [], top_p: float | None = None,
response_format=response_format, user: str | None = None,
stream=stream, ) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
logprobs=logprobs, raise NotImplementedError("OpenAI chat completion not supported by meta-reference inference provider")
tool_config=tool_config or ToolConfig(),
)
self.check_model(request)
# augment and rewrite messages depending on the model
request.messages = chat_completion_request_to_messages(request, self.llama_model.core_model_id.value)
# download media and convert to raw content so we can send it to the model
request = await convert_request_to_raw(request)
if self.config.create_distributed_process_group:
if SEMAPHORE.locked():
raise RuntimeError("Only one concurrent request is supported")
if request.stream:
return self._stream_chat_completion(request)
else:
results = await self._nonstream_chat_completion([request])
return results[0]
async def _nonstream_chat_completion(
self, request_batch: list[ChatCompletionRequest]
) -> list[ChatCompletionResponse]:
tokenizer = self.generator.formatter.tokenizer
first_request = request_batch[0]
class ItemState(BaseModel):
tokens: list[int] = []
logprobs: list[TokenLogProbs] = []
stop_reason: StopReason | None = None
finished: bool = False
def impl():
states = [ItemState() for _ in request_batch]
for token_results in self.generator.chat_completion(request_batch):
first = token_results[0]
if not first.finished and not first.ignore_token:
if os.environ.get("LLAMA_MODELS_DEBUG", "0") in ("1", "2"):
cprint(first.text, color="cyan", end="", file=sys.stderr)
if os.environ.get("LLAMA_MODELS_DEBUG", "0") == "2":
cprint(f"<{first.token}>", color="magenta", end="", file=sys.stderr)
for result in token_results:
idx = result.batch_idx
state = states[idx]
if state.finished or result.ignore_token:
continue
state.finished = result.finished
if first_request.logprobs:
state.logprobs.append(TokenLogProbs(logprobs_by_token={result.text: result.logprobs[0]}))
state.tokens.append(result.token)
if result.token == tokenizer.eot_id:
                        state.stop_reason = StopReason.end_of_turn
                    elif result.token == tokenizer.eom_id:
                        state.stop_reason = StopReason.end_of_message

            results = []
            for state in states:
                if state.stop_reason is None:
                    state.stop_reason = StopReason.out_of_tokens

                raw_message = self.generator.formatter.decode_assistant_message(state.tokens, state.stop_reason)
                results.append(
                    ChatCompletionResponse(
                        completion_message=CompletionMessage(
                            content=raw_message.content,
                            stop_reason=raw_message.stop_reason,
                            tool_calls=raw_message.tool_calls,
                        ),
                        logprobs=state.logprobs if first_request.logprobs else None,
                    )
                )
            return results

        if self.config.create_distributed_process_group:
            async with SEMAPHORE:
                return impl()
        else:
            return impl()

    async def _stream_chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
        tokenizer = self.generator.formatter.tokenizer

        def impl():
            yield ChatCompletionResponseStreamChunk(
                event=ChatCompletionResponseEvent(
                    event_type=ChatCompletionResponseEventType.start,
                    delta=TextDelta(text=""),
                )
            )

            tokens = []
            logprobs = []
            stop_reason = None
            ipython = False

            for token_results in self.generator.chat_completion([request]):
                token_result = token_results[0]
                if os.environ.get("LLAMA_MODELS_DEBUG", "0") == "1":
                    cprint(token_result.text, color="cyan", end="", file=sys.stderr)
                if os.environ.get("LLAMA_MODELS_DEBUG", "0") == "2":
                    cprint(f"<{token_result.token}>", color="magenta", end="", file=sys.stderr)

                if token_result.token == tokenizer.eot_id:
                    stop_reason = StopReason.end_of_turn
                    text = ""
                elif token_result.token == tokenizer.eom_id:
                    stop_reason = StopReason.end_of_message
                    text = ""
                else:
                    text = token_result.text

                if request.logprobs:
                    assert len(token_result.logprobs) == 1
                    logprobs.append(TokenLogProbs(logprobs_by_token={token_result.text: token_result.logprobs[0]}))

                tokens.append(token_result.token)

                if not ipython and token_result.text.startswith("<|python_tag|>"):
                    ipython = True
                    yield ChatCompletionResponseStreamChunk(
                        event=ChatCompletionResponseEvent(
                            event_type=ChatCompletionResponseEventType.progress,
                            delta=ToolCallDelta(
                                tool_call="",
                                parse_status=ToolCallParseStatus.started,
                            ),
                        )
                    )
                    continue

                if token_result.token == tokenizer.eot_id:
                    stop_reason = StopReason.end_of_turn
                    text = ""
                elif token_result.token == tokenizer.eom_id:
                    stop_reason = StopReason.end_of_message
                    text = ""
                else:
                    text = token_result.text

                if ipython:
                    delta = ToolCallDelta(
                        tool_call=text,
                        parse_status=ToolCallParseStatus.in_progress,
                    )
                else:
                    delta = TextDelta(text=text)

                if stop_reason is None:
                    if request.logprobs:
                        assert len(token_result.logprobs) == 1
                        logprobs.append(TokenLogProbs(logprobs_by_token={token_result.text: token_result.logprobs[0]}))

                    yield ChatCompletionResponseStreamChunk(
                        event=ChatCompletionResponseEvent(
                            event_type=ChatCompletionResponseEventType.progress,
                            delta=delta,
                            stop_reason=stop_reason,
                            logprobs=logprobs if request.logprobs else None,
                        )
                    )

            if stop_reason is None:
                stop_reason = StopReason.out_of_tokens

            message = self.generator.formatter.decode_assistant_message(tokens, stop_reason)

            parsed_tool_calls = len(message.tool_calls) > 0
            if ipython and not parsed_tool_calls:
                yield ChatCompletionResponseStreamChunk(
                    event=ChatCompletionResponseEvent(
                        event_type=ChatCompletionResponseEventType.progress,
                        delta=ToolCallDelta(
                            tool_call="",
                            parse_status=ToolCallParseStatus.failed,
                        ),
                        stop_reason=stop_reason,
                    )
                )

            for tool_call in message.tool_calls:
                yield ChatCompletionResponseStreamChunk(
                    event=ChatCompletionResponseEvent(
                        event_type=ChatCompletionResponseEventType.progress,
                        delta=ToolCallDelta(
                            tool_call=tool_call,
                            parse_status=ToolCallParseStatus.succeeded,
                        ),
                        stop_reason=stop_reason,
                    )
                )

            yield ChatCompletionResponseStreamChunk(
                event=ChatCompletionResponseEvent(
                    event_type=ChatCompletionResponseEventType.complete,
                    delta=TextDelta(text=""),
                    stop_reason=stop_reason,
                )
            )

        if self.config.create_distributed_process_group:
            async with SEMAPHORE:
                for x in impl():
                    yield x
        else:
            for x in impl():
                yield x


@@ -4,21 +4,19 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from collections.abc import AsyncGenerator
+from collections.abc import AsyncIterator
 from typing import Any
 
 from llama_stack.apis.inference import (
     InferenceProvider,
-    LogProbConfig,
-    Message,
-    ResponseFormat,
-    SamplingParams,
-    ToolChoice,
-    ToolConfig,
-    ToolDefinition,
-    ToolPromptFormat,
 )
-from llama_stack.apis.inference.inference import OpenAICompletion
+from llama_stack.apis.inference.inference import (
+    OpenAIChatCompletion,
+    OpenAIChatCompletionChunk,
+    OpenAICompletion,
+    OpenAIMessageParam,
+    OpenAIResponseFormatParam,
+)
 from llama_stack.apis.models import ModelType
 from llama_stack.log import get_logger
 from llama_stack.providers.datatypes import Model, ModelsProtocolPrivate
@@ -69,21 +67,6 @@ class SentenceTransformersInferenceImpl(
     async def unregister_model(self, model_id: str) -> None:
         pass
 
-    async def chat_completion(
-        self,
-        model_id: str,
-        messages: list[Message],
-        sampling_params: SamplingParams | None = None,
-        response_format: ResponseFormat | None = None,
-        tools: list[ToolDefinition] | None = None,
-        tool_choice: ToolChoice | None = ToolChoice.auto,
-        tool_prompt_format: ToolPromptFormat | None = None,
-        stream: bool | None = False,
-        logprobs: LogProbConfig | None = None,
-        tool_config: ToolConfig | None = None,
-    ) -> AsyncGenerator:
-        raise ValueError("Sentence transformers don't support chat completion")
-
     async def openai_completion(
         self,
         # Standard OpenAI completion parameters
@@ -110,6 +93,32 @@
         # for fill-in-the-middle type completion
         suffix: str | None = None,
     ) -> OpenAICompletion:
-        raise NotImplementedError(
-            "OpenAI completion not supported by sentence transformers provider"
-        )
+        raise NotImplementedError("OpenAI completion not supported by sentence transformers provider")
+
+    async def openai_chat_completion(
+        self,
+        model: str,
+        messages: list[OpenAIMessageParam],
+        frequency_penalty: float | None = None,
+        function_call: str | dict[str, Any] | None = None,
+        functions: list[dict[str, Any]] | None = None,
+        logit_bias: dict[str, float] | None = None,
+        logprobs: bool | None = None,
+        max_completion_tokens: int | None = None,
+        max_tokens: int | None = None,
+        n: int | None = None,
+        parallel_tool_calls: bool | None = None,
+        presence_penalty: float | None = None,
+        response_format: OpenAIResponseFormatParam | None = None,
+        seed: int | None = None,
+        stop: str | list[str] | None = None,
+        stream: bool | None = None,
+        stream_options: dict[str, Any] | None = None,
+        temperature: float | None = None,
+        tool_choice: str | dict[str, Any] | None = None,
+        tools: list[dict[str, Any]] | None = None,
+        top_logprobs: int | None = None,
+        top_p: float | None = None,
+        user: str | None = None,
+    ) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
+        raise NotImplementedError("OpenAI chat completion not supported by sentence transformers provider")


@@ -52,9 +52,7 @@ def available_providers() -> list[ProviderSpec]:
             api=Api.inference,
             adapter_type="cerebras",
             provider_type="remote::cerebras",
-            pip_packages=[
-                "cerebras_cloud_sdk",
-            ],
+            pip_packages=[],
             module="llama_stack.providers.remote.inference.cerebras",
             config_class="llama_stack.providers.remote.inference.cerebras.CerebrasImplConfig",
             description="Cerebras inference provider for running models on Cerebras Cloud platform.",
@@ -169,7 +167,7 @@ def available_providers() -> list[ProviderSpec]:
             api=Api.inference,
             adapter_type="openai",
             provider_type="remote::openai",
-            pip_packages=["litellm"],
+            pip_packages=[],
             module="llama_stack.providers.remote.inference.openai",
             config_class="llama_stack.providers.remote.inference.openai.OpenAIConfig",
             provider_data_validator="llama_stack.providers.remote.inference.openai.config.OpenAIProviderDataValidator",
@@ -179,7 +177,7 @@ def available_providers() -> list[ProviderSpec]:
             api=Api.inference,
             adapter_type="anthropic",
             provider_type="remote::anthropic",
-            pip_packages=["litellm"],
+            pip_packages=["anthropic"],
             module="llama_stack.providers.remote.inference.anthropic",
             config_class="llama_stack.providers.remote.inference.anthropic.AnthropicConfig",
             provider_data_validator="llama_stack.providers.remote.inference.anthropic.config.AnthropicProviderDataValidator",
@@ -189,9 +187,7 @@ def available_providers() -> list[ProviderSpec]:
             api=Api.inference,
             adapter_type="gemini",
             provider_type="remote::gemini",
-            pip_packages=[
-                "litellm",
-            ],
+            pip_packages=[],
             module="llama_stack.providers.remote.inference.gemini",
             config_class="llama_stack.providers.remote.inference.gemini.GeminiConfig",
             provider_data_validator="llama_stack.providers.remote.inference.gemini.config.GeminiProviderDataValidator",
@@ -202,7 +198,6 @@ def available_providers() -> list[ProviderSpec]:
             adapter_type="vertexai",
             provider_type="remote::vertexai",
             pip_packages=[
-                "litellm",
                 "google-cloud-aiplatform",
             ],
             module="llama_stack.providers.remote.inference.vertexai",
@@ -233,9 +228,7 @@ Available Models:
             api=Api.inference,
             adapter_type="groq",
             provider_type="remote::groq",
-            pip_packages=[
-                "litellm",
-            ],
+            pip_packages=[],
             module="llama_stack.providers.remote.inference.groq",
             config_class="llama_stack.providers.remote.inference.groq.GroqConfig",
             provider_data_validator="llama_stack.providers.remote.inference.groq.config.GroqProviderDataValidator",
@@ -245,7 +238,7 @@ Available Models:
             api=Api.inference,
             adapter_type="llama-openai-compat",
             provider_type="remote::llama-openai-compat",
-            pip_packages=["litellm"],
+            pip_packages=[],
             module="llama_stack.providers.remote.inference.llama_openai_compat",
             config_class="llama_stack.providers.remote.inference.llama_openai_compat.config.LlamaCompatConfig",
             provider_data_validator="llama_stack.providers.remote.inference.llama_openai_compat.config.LlamaProviderDataValidator",
@@ -255,9 +248,7 @@ Available Models:
             api=Api.inference,
             adapter_type="sambanova",
             provider_type="remote::sambanova",
-            pip_packages=[
-                "litellm",
-            ],
+            pip_packages=[],
             module="llama_stack.providers.remote.inference.sambanova",
             config_class="llama_stack.providers.remote.inference.sambanova.SambaNovaImplConfig",
             provider_data_validator="llama_stack.providers.remote.inference.sambanova.config.SambaNovaProviderDataValidator",
@@ -287,7 +278,7 @@ Available Models:
             api=Api.inference,
             provider_type="remote::azure",
             adapter_type="azure",
-            pip_packages=["litellm"],
+            pip_packages=[],
             module="llama_stack.providers.remote.inference.azure",
             config_class="llama_stack.providers.remote.inference.azure.AzureConfig",
             provider_data_validator="llama_stack.providers.remote.inference.azure.config.AzureProviderDataValidator",


@@ -500,7 +500,7 @@ See [PGVector's documentation](https://github.com/pgvector/pgvector) for more details
             api=Api.vector_io,
             adapter_type="weaviate",
             provider_type="remote::weaviate",
-            pip_packages=["weaviate-client"],
+            pip_packages=["weaviate-client>=4.16.5"],
             module="llama_stack.providers.remote.vector_io.weaviate",
             config_class="llama_stack.providers.remote.vector_io.weaviate.WeaviateVectorIOConfig",
             provider_data_validator="llama_stack.providers.remote.vector_io.weaviate.WeaviateRequestProviderData",


@@ -10,6 +10,6 @@ from .config import AnthropicConfig
 async def get_adapter_impl(config: AnthropicConfig, _deps):
     from .anthropic import AnthropicInferenceAdapter
 
-    impl = AnthropicInferenceAdapter(config)
+    impl = AnthropicInferenceAdapter(config=config)
     await impl.initialize()
     return impl


@@ -4,13 +4,19 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
+from collections.abc import Iterable
+
+from anthropic import AsyncAnthropic
+
 from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
 
 from .config import AnthropicConfig
 
 
-class AnthropicInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin):
+class AnthropicInferenceAdapter(OpenAIMixin):
+    config: AnthropicConfig
+
+    provider_data_api_key_field: str = "anthropic_api_key"
+
     # source: https://docs.claude.com/en/docs/build-with-claude/embeddings
     # TODO: add support for voyageai, which is where these models are hosted
     # embedding_model_metadata = {
@@ -23,22 +29,11 @@ class AnthropicInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin):
     #     "voyage-multimodal-3": {"embedding_dimension": 1024, "context_length": 32000},
     # }
 
-    def __init__(self, config: AnthropicConfig) -> None:
-        LiteLLMOpenAIMixin.__init__(
-            self,
-            litellm_provider_name="anthropic",
-            api_key_from_config=config.api_key,
-            provider_data_api_key_field="anthropic_api_key",
-        )
-        self.config = config
-
-    async def initialize(self) -> None:
-        await super().initialize()
-
-    async def shutdown(self) -> None:
-        await super().shutdown()
-
-    get_api_key = LiteLLMOpenAIMixin.get_api_key
+    def get_api_key(self) -> str:
+        return self.config.api_key or ""
 
     def get_base_url(self):
         return "https://api.anthropic.com/v1"
+
+    async def list_provider_model_ids(self) -> Iterable[str]:
+        return [m.id async for m in AsyncAnthropic(api_key=self.get_api_key()).models.list()]
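
The rewritten adapters in this commit all converge on the same declarative OpenAIMixin shape; as a minimal sketch with a hypothetical provider (only get_api_key, get_base_url, and provider_data_api_key_field mirror the diff, the Example* names are illustrative):

    from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin

    class ExampleInferenceAdapter(OpenAIMixin):
        config: ExampleConfig  # hypothetical RemoteInferenceProviderConfig subclass
        provider_data_api_key_field: str = "example_api_key"

        def get_api_key(self) -> str:
            return self.config.api_key or ""

        def get_base_url(self) -> str:
            return "https://api.example.com/v1"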


@@ -8,6 +8,7 @@ from typing import Any
 
 from pydantic import BaseModel, Field
 
+from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
 from llama_stack.schema_utils import json_schema_type
 
 
@@ -19,7 +20,7 @@ class AnthropicProviderDataValidator(BaseModel):
 
 
 @json_schema_type
-class AnthropicConfig(BaseModel):
+class AnthropicConfig(RemoteInferenceProviderConfig):
     api_key: str | None = Field(
         default=None,
         description="API key for Anthropic models",


@@ -10,6 +10,6 @@ from .config import AzureConfig
 async def get_adapter_impl(config: AzureConfig, _deps):
     from .azure import AzureInferenceAdapter
 
-    impl = AzureInferenceAdapter(config)
+    impl = AzureInferenceAdapter(config=config)
     await impl.initialize()
     return impl


@@ -4,31 +4,20 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from typing import Any
 from urllib.parse import urljoin
 
-from llama_stack.apis.inference import ChatCompletionRequest
-from llama_stack.providers.utils.inference.litellm_openai_mixin import (
-    LiteLLMOpenAIMixin,
-)
 from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
 
 from .config import AzureConfig
 
 
-class AzureInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin):
-    def __init__(self, config: AzureConfig) -> None:
-        LiteLLMOpenAIMixin.__init__(
-            self,
-            litellm_provider_name="azure",
-            api_key_from_config=config.api_key.get_secret_value(),
-            provider_data_api_key_field="azure_api_key",
-            openai_compat_api_base=str(config.api_base),
-        )
-        self.config = config
+class AzureInferenceAdapter(OpenAIMixin):
+    config: AzureConfig
 
-    # Delegate the client data handling get_api_key method to LiteLLMOpenAIMixin
-    get_api_key = LiteLLMOpenAIMixin.get_api_key
+    provider_data_api_key_field: str = "azure_api_key"
+
+    def get_api_key(self) -> str:
+        return self.config.api_key.get_secret_value()
 
     def get_base_url(self) -> str:
         """
@@ -37,26 +26,3 @@ class AzureInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin):
         Returns the Azure API base URL from the configuration.
         """
         return urljoin(str(self.config.api_base), "/openai/v1")
-
-    async def _get_params(self, request: ChatCompletionRequest) -> dict[str, Any]:
-        # Get base parameters from parent
-        params = await super()._get_params(request)
-
-        # Add Azure specific parameters
-        provider_data = self.get_request_provider_data()
-        if provider_data:
-            if getattr(provider_data, "azure_api_key", None):
-                params["api_key"] = provider_data.azure_api_key
-            if getattr(provider_data, "azure_api_base", None):
-                params["api_base"] = provider_data.azure_api_base
-            if getattr(provider_data, "azure_api_version", None):
-                params["api_version"] = provider_data.azure_api_version
-            if getattr(provider_data, "azure_api_type", None):
-                params["api_type"] = provider_data.azure_api_type
-        else:
-            params["api_key"] = self.config.api_key.get_secret_value()
-            params["api_base"] = str(self.config.api_base)
-            params["api_version"] = self.config.api_version
-            params["api_type"] = self.config.api_type
-
-        return params
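
One subtlety worth noting in get_base_url: urljoin with an absolute second argument ("/openai/v1") replaces everything after the host, which appears to be the intent here. A quick illustration (the URLs are made up):

    from urllib.parse import urljoin

    # An absolute second argument discards the base URL's path:
    urljoin("https://example.openai.azure.com/some/path", "/openai/v1")
    # -> "https://example.openai.azure.com/openai/v1"

    # A relative argument, as in the Cerebras adapter below, appends to or
    # replaces only the last path segment:
    urljoin("https://api.cerebras.ai", "v1")
    # -> "https://api.cerebras.ai/v1"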


@@ -9,6 +9,7 @@ from typing import Any
 
 from pydantic import BaseModel, Field, HttpUrl, SecretStr
 
+from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
 from llama_stack.schema_utils import json_schema_type
 
 
@@ -30,7 +31,7 @@ class AzureProviderDataValidator(BaseModel):
 
 
 @json_schema_type
-class AzureConfig(BaseModel):
+class AzureConfig(RemoteInferenceProviderConfig):
     api_key: SecretStr = Field(
         description="Azure API key for Azure",
     )


@@ -5,39 +5,30 @@
 # the root directory of this source tree.
 
 import json
-from collections.abc import AsyncGenerator, AsyncIterator
+from collections.abc import AsyncIterator
 from typing import Any
 
 from botocore.client import BaseClient
 
 from llama_stack.apis.inference import (
     ChatCompletionRequest,
-    ChatCompletionResponse,
-    ChatCompletionResponseStreamChunk,
     Inference,
-    LogProbConfig,
-    Message,
     OpenAIEmbeddingsResponse,
-    ResponseFormat,
-    SamplingParams,
-    ToolChoice,
-    ToolConfig,
-    ToolDefinition,
-    ToolPromptFormat,
 )
-from llama_stack.apis.inference.inference import OpenAICompletion
+from llama_stack.apis.inference.inference import (
+    OpenAIChatCompletion,
+    OpenAIChatCompletionChunk,
+    OpenAICompletion,
+    OpenAIMessageParam,
+    OpenAIResponseFormatParam,
+)
 from llama_stack.providers.remote.inference.bedrock.config import BedrockConfig
 from llama_stack.providers.utils.bedrock.client import create_bedrock_client
 from llama_stack.providers.utils.inference.model_registry import (
     ModelRegistryHelper,
 )
 from llama_stack.providers.utils.inference.openai_compat import (
-    OpenAIChatCompletionToLlamaStackMixin,
-    OpenAICompatCompletionChoice,
-    OpenAICompatCompletionResponse,
     get_sampling_strategy_options,
-    process_chat_completion_response,
-    process_chat_completion_stream_response,
 )
 from llama_stack.providers.utils.inference.prompt_adapter import (
     chat_completion_request_to_prompt,
@@ -86,7 +77,6 @@ def _to_inference_profile_id(model_id: str, region: str = None) -> str:
 class BedrockInferenceAdapter(
     ModelRegistryHelper,
     Inference,
-    OpenAIChatCompletionToLlamaStackMixin,
 ):
     def __init__(self, config: BedrockConfig) -> None:
         ModelRegistryHelper.__init__(self, model_entries=MODEL_ENTRIES)
@@ -106,71 +96,6 @@ class BedrockInferenceAdapter(
         if self._client is not None:
             self._client.close()
 
-    async def chat_completion(
-        self,
-        model_id: str,
-        messages: list[Message],
-        sampling_params: SamplingParams | None = None,
-        response_format: ResponseFormat | None = None,
-        tools: list[ToolDefinition] | None = None,
-        tool_choice: ToolChoice | None = ToolChoice.auto,
-        tool_prompt_format: ToolPromptFormat | None = None,
-        stream: bool | None = False,
-        logprobs: LogProbConfig | None = None,
-        tool_config: ToolConfig | None = None,
-    ) -> ChatCompletionResponse | AsyncIterator[ChatCompletionResponseStreamChunk]:
-        if sampling_params is None:
-            sampling_params = SamplingParams()
-        model = await self.model_store.get_model(model_id)
-        request = ChatCompletionRequest(
-            model=model.provider_resource_id,
-            messages=messages,
-            sampling_params=sampling_params,
-            tools=tools or [],
-            response_format=response_format,
-            stream=stream,
-            logprobs=logprobs,
-            tool_config=tool_config,
-        )
-        if stream:
-            return self._stream_chat_completion(request)
-        else:
-            return await self._nonstream_chat_completion(request)
-
-    async def _nonstream_chat_completion(self, request: ChatCompletionRequest) -> ChatCompletionResponse:
-        params = await self._get_params_for_chat_completion(request)
-        res = self.client.invoke_model(**params)
-        chunk = next(res["body"])
-        result = json.loads(chunk.decode("utf-8"))
-
-        choice = OpenAICompatCompletionChoice(
-            finish_reason=result["stop_reason"],
-            text=result["generation"],
-        )
-
-        response = OpenAICompatCompletionResponse(choices=[choice])
-        return process_chat_completion_response(response, request)
-
-    async def _stream_chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
-        params = await self._get_params_for_chat_completion(request)
-        res = self.client.invoke_model_with_response_stream(**params)
-        event_stream = res["body"]
-
-        async def _generate_and_convert_to_openai_compat():
-            for chunk in event_stream:
-                chunk = chunk["chunk"]["bytes"]
-                result = json.loads(chunk.decode("utf-8"))
-                choice = OpenAICompatCompletionChoice(
-                    finish_reason=result["stop_reason"],
-                    text=result["generation"],
-                )
-                yield OpenAICompatCompletionResponse(choices=[choice])
-
-        stream = _generate_and_convert_to_openai_compat()
-        async for chunk in process_chat_completion_stream_response(stream, request):
-            yield chunk
-
     async def _get_params_for_chat_completion(self, request: ChatCompletionRequest) -> dict:
         bedrock_model = request.model
@@ -235,3 +160,31 @@ class BedrockInferenceAdapter(
         suffix: str | None = None,
     ) -> OpenAICompletion:
         raise NotImplementedError("OpenAI completion not supported by the Bedrock provider")
+
+    async def openai_chat_completion(
+        self,
+        model: str,
+        messages: list[OpenAIMessageParam],
+        frequency_penalty: float | None = None,
+        function_call: str | dict[str, Any] | None = None,
+        functions: list[dict[str, Any]] | None = None,
+        logit_bias: dict[str, float] | None = None,
+        logprobs: bool | None = None,
+        max_completion_tokens: int | None = None,
+        max_tokens: int | None = None,
+        n: int | None = None,
+        parallel_tool_calls: bool | None = None,
+        presence_penalty: float | None = None,
+        response_format: OpenAIResponseFormatParam | None = None,
+        seed: int | None = None,
+        stop: str | list[str] | None = None,
+        stream: bool | None = None,
+        stream_options: dict[str, Any] | None = None,
+        temperature: float | None = None,
+        tool_choice: str | dict[str, Any] | None = None,
+        tools: list[dict[str, Any]] | None = None,
+        top_logprobs: int | None = None,
+        top_p: float | None = None,
+        user: str | None = None,
+    ) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
+        raise NotImplementedError("OpenAI chat completion not supported by the Bedrock provider")
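
The deleted streaming path parsed Bedrock's event stream by hand; a condensed sketch of that pattern, kept here for reference (boto3 client setup elided, the helper name is illustrative):

    import json

    def iter_generations(client, model_id: str, body: str):
        # Each stream event wraps a JSON payload carrying the generated text and a
        # stop reason, as in the removed _generate_and_convert_to_openai_compat.
        res = client.invoke_model_with_response_stream(modelId=model_id, body=body)
        for event in res["body"]:
            result = json.loads(event["chunk"]["bytes"].decode("utf-8"))
            yield result["generation"], result["stop_reason"]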


@@ -12,7 +12,7 @@ async def get_adapter_impl(config: CerebrasImplConfig, _deps):
 
     assert isinstance(config, CerebrasImplConfig), f"Unexpected config type: {type(config)}"
 
-    impl = CerebrasInferenceAdapter(config)
+    impl = CerebrasInferenceAdapter(config=config)
 
     await impl.initialize()


@@ -4,53 +4,16 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from collections.abc import AsyncGenerator
 from urllib.parse import urljoin
 
-from cerebras.cloud.sdk import AsyncCerebras
-
-from llama_stack.apis.inference import (
-    ChatCompletionRequest,
-    ChatCompletionResponse,
-    Inference,
-    LogProbConfig,
-    Message,
-    OpenAIEmbeddingsResponse,
-    ResponseFormat,
-    SamplingParams,
-    ToolChoice,
-    ToolConfig,
-    ToolDefinition,
-    ToolPromptFormat,
-    TopKSamplingStrategy,
-)
-from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
-from llama_stack.providers.utils.inference.openai_compat import (
-    get_sampling_options,
-    process_chat_completion_response,
-    process_chat_completion_stream_response,
-)
+from llama_stack.apis.inference import OpenAIEmbeddingsResponse
 from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
-from llama_stack.providers.utils.inference.prompt_adapter import (
-    chat_completion_request_to_prompt,
-)
 
 from .config import CerebrasImplConfig
 
 
-class CerebrasInferenceAdapter(
-    OpenAIMixin,
-    ModelRegistryHelper,
-    Inference,
-):
-    def __init__(self, config: CerebrasImplConfig) -> None:
-        self.config = config
-
-        # TODO: make this use provider data, etc. like other providers
-        self._cerebras_client = AsyncCerebras(
-            base_url=self.config.base_url,
-            api_key=self.config.api_key.get_secret_value(),
-        )
+class CerebrasInferenceAdapter(OpenAIMixin):
+    config: CerebrasImplConfig
 
     def get_api_key(self) -> str:
         return self.config.api_key.get_secret_value()
@@ -58,86 +21,6 @@ class CerebrasInferenceAdapter(
     def get_base_url(self) -> str:
         return urljoin(self.config.base_url, "v1")
 
-    async def initialize(self) -> None:
-        return
-
-    async def shutdown(self) -> None:
-        pass
-
-    async def chat_completion(
-        self,
-        model_id: str,
-        messages: list[Message],
-        sampling_params: SamplingParams | None = None,
-        tools: list[ToolDefinition] | None = None,
-        tool_choice: ToolChoice | None = ToolChoice.auto,
-        tool_prompt_format: ToolPromptFormat | None = None,
-        response_format: ResponseFormat | None = None,
-        stream: bool | None = False,
-        logprobs: LogProbConfig | None = None,
-        tool_config: ToolConfig | None = None,
-    ) -> AsyncGenerator:
-        if sampling_params is None:
-            sampling_params = SamplingParams()
-
-        model = await self.model_store.get_model(model_id)
-        request = ChatCompletionRequest(
-            model=model.provider_resource_id,
-            messages=messages,
-            sampling_params=sampling_params,
-            tools=tools or [],
-            tool_choice=tool_choice,
-            tool_prompt_format=tool_prompt_format,
-            response_format=response_format,
-            stream=stream,
-            logprobs=logprobs,
-            tool_config=tool_config,
-        )
-
-        if stream:
-            return self._stream_chat_completion(request)
-        else:
-            return await self._nonstream_chat_completion(request)
-
-    async def _nonstream_chat_completion(
-        self, request: ChatCompletionRequest
-    ) -> ChatCompletionResponse:
-        params = await self._get_params(request)
-
-        r = await self._cerebras_client.completions.create(**params)
-
-        return process_chat_completion_response(r, request)
-
-    async def _stream_chat_completion(
-        self, request: ChatCompletionRequest
-    ) -> AsyncGenerator:
-        params = await self._get_params(request)
-
-        stream = await self._cerebras_client.completions.create(**params)
-
-        async for chunk in process_chat_completion_stream_response(stream, request):
-            yield chunk
-
-    async def _get_params(self, request: ChatCompletionRequest) -> dict:
-        if request.sampling_params and isinstance(
-            request.sampling_params.strategy, TopKSamplingStrategy
-        ):
-            raise ValueError("`top_k` not supported by Cerebras")
-
-        prompt = ""
-        if isinstance(request, ChatCompletionRequest):
-            prompt = await chat_completion_request_to_prompt(
-                request, self.get_llama_model(request.model)
-            )
-        else:
-            raise ValueError(f"Unknown request type {type(request)}")
-
-        return {
-            "model": request.model,
-            "prompt": prompt,
-            "stream": request.stream,
-            **get_sampling_options(request.sampling_params),
-        }
-
     async def openai_embeddings(
         self,
         model: str,


@@ -7,21 +7,22 @@
 import os
 from typing import Any
 
-from pydantic import BaseModel, Field, SecretStr
+from pydantic import Field, SecretStr
 
+from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
 from llama_stack.schema_utils import json_schema_type
 
 DEFAULT_BASE_URL = "https://api.cerebras.ai"
 
 
 @json_schema_type
-class CerebrasImplConfig(BaseModel):
+class CerebrasImplConfig(RemoteInferenceProviderConfig):
     base_url: str = Field(
         default=os.environ.get("CEREBRAS_BASE_URL", DEFAULT_BASE_URL),
         description="Base URL for the Cerebras API",
     )
     api_key: SecretStr = Field(
-        default=SecretStr(os.environ.get("CEREBRAS_API_KEY")),
+        default=SecretStr(os.environ.get("CEREBRAS_API_KEY")),  # type: ignore[arg-type]
         description="Cerebras API Key",
     )


@@ -11,6 +11,6 @@ async def get_adapter_impl(config: DatabricksImplConfig, _deps):
     from .databricks import DatabricksInferenceAdapter
 
     assert isinstance(config, DatabricksImplConfig), f"Unexpected config type: {type(config)}"
-    impl = DatabricksInferenceAdapter(config)
+    impl = DatabricksInferenceAdapter(config=config)
     await impl.initialize()
     return impl


@@ -6,19 +6,20 @@
 
 from typing import Any
 
-from pydantic import BaseModel, Field, SecretStr
+from pydantic import Field, SecretStr
 
+from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
 from llama_stack.schema_utils import json_schema_type
 
 
 @json_schema_type
-class DatabricksImplConfig(BaseModel):
-    url: str = Field(
+class DatabricksImplConfig(RemoteInferenceProviderConfig):
+    url: str | None = Field(
         default=None,
         description="The URL for the Databricks model serving endpoint",
     )
     api_token: SecretStr = Field(
-        default=SecretStr(None),
+        default=SecretStr(None),  # type: ignore[arg-type]
         description="The Databricks API token",
     )


@@ -4,27 +4,12 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from collections.abc import AsyncIterator
+from collections.abc import Iterable
 from typing import Any
 
 from databricks.sdk import WorkspaceClient
 
-from llama_stack.apis.inference import (
-    ChatCompletionResponse,
-    ChatCompletionResponseStreamChunk,
-    Inference,
-    LogProbConfig,
-    Message,
-    Model,
-    OpenAICompletion,
-    ResponseFormat,
-    SamplingParams,
-    ToolChoice,
-    ToolConfig,
-    ToolDefinition,
-    ToolPromptFormat,
-)
-from llama_stack.apis.models import ModelType
+from llama_stack.apis.inference import OpenAICompletion
 from llama_stack.log import get_logger
 from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
@@ -33,30 +18,31 @@ from .config import DatabricksImplConfig
 logger = get_logger(name=__name__, category="inference::databricks")
 
 
-class DatabricksInferenceAdapter(
-    OpenAIMixin,
-    Inference,
-):
+class DatabricksInferenceAdapter(OpenAIMixin):
+    config: DatabricksImplConfig
+
     # source: https://docs.databricks.com/aws/en/machine-learning/foundation-model-apis/supported-models
-    embedding_model_metadata = {
+    embedding_model_metadata: dict[str, dict[str, int]] = {
         "databricks-gte-large-en": {"embedding_dimension": 1024, "context_length": 8192},
         "databricks-bge-large-en": {"embedding_dimension": 1024, "context_length": 512},
     }
 
-    def __init__(self, config: DatabricksImplConfig) -> None:
-        self.config = config
-
     def get_api_key(self) -> str:
         return self.config.api_token.get_secret_value()
 
     def get_base_url(self) -> str:
         return f"{self.config.url}/serving-endpoints"
 
-    async def initialize(self) -> None:
-        return
+    async def list_provider_model_ids(self) -> Iterable[str]:
+        return [
+            endpoint.name
+            for endpoint in WorkspaceClient(
+                host=self.config.url, token=self.get_api_key()
+            ).serving_endpoints.list()  # TODO: this is not async
+        ]
 
-    async def shutdown(self) -> None:
-        pass
+    async def should_refresh_models(self) -> bool:
+        return False
 
     async def openai_completion(
         self,
@@ -82,47 +68,3 @@ class DatabricksInferenceAdapter(
         suffix: str | None = None,
     ) -> OpenAICompletion:
         raise NotImplementedError()
-
-    async def chat_completion(
-        self,
-        model_id: str,
-        messages: list[Message],
-        sampling_params: SamplingParams | None = None,
-        tools: list[ToolDefinition] | None = None,
-        tool_choice: ToolChoice | None = ToolChoice.auto,
-        tool_prompt_format: ToolPromptFormat | None = None,
-        response_format: ResponseFormat | None = None,
-        stream: bool | None = False,
-        logprobs: LogProbConfig | None = None,
-        tool_config: ToolConfig | None = None,
-    ) -> ChatCompletionResponse | AsyncIterator[ChatCompletionResponseStreamChunk]:
-        raise NotImplementedError()
-
-    async def list_models(self) -> list[Model] | None:
-        self._model_cache = {}  # from OpenAIMixin
-
-        ws_client = WorkspaceClient(host=self.config.url, token=self.get_api_key())  # TODO: this is not async
-        endpoints = ws_client.serving_endpoints.list()
-        for endpoint in endpoints:
-            model = Model(
-                provider_id=self.__provider_id__,
-                provider_resource_id=endpoint.name,
-                identifier=endpoint.name,
-            )
-            if endpoint.task == "llm/v1/chat":
-                model.model_type = ModelType.llm  # this is redundant, but informative
-            elif endpoint.task == "llm/v1/embeddings":
-                if endpoint.name not in self.embedding_model_metadata:
-                    logger.warning(f"No metadata information available for embedding model {endpoint.name}, skipping.")
-                    continue
-                model.model_type = ModelType.embedding
-                model.metadata = self.embedding_model_metadata[endpoint.name]
-            else:
-                logger.warning(f"Unknown model type, skipping: {endpoint}")
-                continue
-            self._model_cache[endpoint.name] = model
-
-        return list(self._model_cache.values())
-
-    async def should_refresh_models(self) -> bool:
-        return False
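
The inline TODO flags that the WorkspaceClient call in list_provider_model_ids is synchronous; one hedged way to keep it off the event loop (not what this diff does) is to push the blocking iteration into a worker thread:

    import asyncio

    from databricks.sdk import WorkspaceClient

    async def list_endpoint_names(host: str, token: str) -> list[str]:
        # Run the blocking SDK call in a thread so the event loop stays responsive;
        # the Databricks SDK has no async variant of serving_endpoints.list().
        def _list() -> list[str]:
            client = WorkspaceClient(host=host, token=token)
            return [e.name for e in client.serving_endpoints.list()]

        return await asyncio.to_thread(_list)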


@@ -17,6 +17,6 @@ async def get_adapter_impl(config: FireworksImplConfig, _deps):
     from .fireworks import FireworksInferenceAdapter
 
     assert isinstance(config, FireworksImplConfig), f"Unexpected config type: {type(config)}"
-    impl = FireworksInferenceAdapter(config)
+    impl = FireworksInferenceAdapter(config=config)
     await impl.initialize()
     return impl


@@ -4,195 +4,27 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from collections.abc import AsyncGenerator
-
-from fireworks.client import Fireworks
-
-from llama_stack.apis.inference import (
-    ChatCompletionRequest,
-    ChatCompletionResponse,
-    Inference,
-    LogProbConfig,
-    Message,
-    ResponseFormat,
-    ResponseFormatType,
-    SamplingParams,
-    ToolChoice,
-    ToolConfig,
-    ToolDefinition,
-    ToolPromptFormat,
-)
-from llama_stack.core.request_headers import NeedsRequestProviderData
 from llama_stack.log import get_logger
-from llama_stack.providers.utils.inference.model_registry import (
-    ModelRegistryHelper,
-)
-from llama_stack.providers.utils.inference.openai_compat import (
-    convert_message_to_openai_dict,
-    get_sampling_options,
-    process_chat_completion_response,
-    process_chat_completion_stream_response,
-)
 from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
-from llama_stack.providers.utils.inference.prompt_adapter import (
-    chat_completion_request_to_prompt,
-    request_has_media,
-)
 
 from .config import FireworksImplConfig
 
 logger = get_logger(name=__name__, category="inference::fireworks")
 
 
-class FireworksInferenceAdapter(OpenAIMixin, ModelRegistryHelper, Inference, NeedsRequestProviderData):
-    embedding_model_metadata = {
+class FireworksInferenceAdapter(OpenAIMixin):
+    config: FireworksImplConfig
+
+    embedding_model_metadata: dict[str, dict[str, int]] = {
         "nomic-ai/nomic-embed-text-v1.5": {"embedding_dimension": 768, "context_length": 8192},
         "accounts/fireworks/models/qwen3-embedding-8b": {"embedding_dimension": 4096, "context_length": 40960},
     }
 
-    def __init__(self, config: FireworksImplConfig) -> None:
-        ModelRegistryHelper.__init__(self)
-        self.config = config
-        self.allowed_models = config.allowed_models
-
-    async def initialize(self) -> None:
-        pass
-
-    async def shutdown(self) -> None:
-        pass
+    provider_data_api_key_field: str = "fireworks_api_key"
 
     def get_api_key(self) -> str:
-        config_api_key = self.config.api_key.get_secret_value() if self.config.api_key else None
-        if config_api_key:
-            return config_api_key
-        else:
-            provider_data = self.get_request_provider_data()
-            if provider_data is None or not provider_data.fireworks_api_key:
-                raise ValueError(
-                    'Pass Fireworks API Key in the header X-LlamaStack-Provider-Data as { "fireworks_api_key": <your api key>}'
-                )
-            return provider_data.fireworks_api_key
+        return self.config.api_key.get_secret_value() if self.config.api_key else None  # type: ignore[return-value]
 
     def get_base_url(self) -> str:
         return "https://api.fireworks.ai/inference/v1"
-
-    def _get_client(self) -> Fireworks:
-        fireworks_api_key = self.get_api_key()
-        return Fireworks(api_key=fireworks_api_key)
-
-    def _preprocess_prompt_for_fireworks(self, prompt: str) -> str:
-        """Remove BOS token as Fireworks automatically prepends it"""
-        if prompt.startswith("<|begin_of_text|>"):
-            return prompt[len("<|begin_of_text|>") :]
-        return prompt
-
-    async def chat_completion(
-        self,
-        model_id: str,
-        messages: list[Message],
-        sampling_params: SamplingParams | None = None,
-        tools: list[ToolDefinition] | None = None,
-        tool_choice: ToolChoice | None = ToolChoice.auto,
-        tool_prompt_format: ToolPromptFormat | None = None,
-        response_format: ResponseFormat | None = None,
-        stream: bool | None = False,
-        logprobs: LogProbConfig | None = None,
-        tool_config: ToolConfig | None = None,
-    ) -> AsyncGenerator:
-        if sampling_params is None:
-            sampling_params = SamplingParams()
-        model = await self.model_store.get_model(model_id)
-        request = ChatCompletionRequest(
-            model=model.provider_resource_id,
-            messages=messages,
-            sampling_params=sampling_params,
-            tools=tools or [],
-            response_format=response_format,
-            stream=stream,
-            logprobs=logprobs,
-            tool_config=tool_config,
-        )
-        if stream:
-            return self._stream_chat_completion(request)
-        else:
-            return await self._nonstream_chat_completion(request)
-
-    async def _nonstream_chat_completion(self, request: ChatCompletionRequest) -> ChatCompletionResponse:
-        params = await self._get_params(request)
-        if "messages" in params:
-            r = await self._get_client().chat.completions.acreate(**params)
-        else:
-            r = await self._get_client().completion.acreate(**params)
-        return process_chat_completion_response(r, request)
-
-    async def _stream_chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
-        params = await self._get_params(request)
-
-        async def _to_async_generator():
-            if "messages" in params:
-                stream = self._get_client().chat.completions.acreate(**params)
-            else:
-                stream = self._get_client().completion.acreate(**params)
-            async for chunk in stream:
-                yield chunk
-
-        stream = _to_async_generator()
-        async for chunk in process_chat_completion_stream_response(stream, request):
-            yield chunk
-
-    def _build_options(
-        self,
-        sampling_params: SamplingParams | None,
-        fmt: ResponseFormat | None,
-        logprobs: LogProbConfig | None,
-    ) -> dict:
-        options = get_sampling_options(sampling_params)
-        options.setdefault("max_tokens", 512)
-
-        if fmt:
-            if fmt.type == ResponseFormatType.json_schema.value:
-                options["response_format"] = {
-                    "type": "json_object",
-                    "schema": fmt.json_schema,
-                }
-            elif fmt.type == ResponseFormatType.grammar.value:
-                options["response_format"] = {
-                    "type": "grammar",
-                    "grammar": fmt.bnf,
-                }
-            else:
-                raise ValueError(f"Unknown response format {fmt.type}")
-
-        if logprobs and logprobs.top_k:
-            options["logprobs"] = logprobs.top_k
-            if options["logprobs"] <= 0 or options["logprobs"] >= 5:
-                raise ValueError("Required range: 0 < top_k < 5")
-
-        return options
-
-    async def _get_params(self, request: ChatCompletionRequest) -> dict:
-        input_dict = {}
-        media_present = request_has_media(request)
-        llama_model = self.get_llama_model(request.model)
-
-        # TODO: tools are never added to the request, so we need to add them here
-        if media_present or not llama_model:
-            input_dict["messages"] = [await convert_message_to_openai_dict(m, download=True) for m in request.messages]
-        else:
-            input_dict["prompt"] = await chat_completion_request_to_prompt(request, llama_model)
-
-        # Fireworks always prepends with BOS
-        if "prompt" in input_dict:
-            if input_dict["prompt"].startswith("<|begin_of_text|>"):
-                input_dict["prompt"] = input_dict["prompt"][len("<|begin_of_text|>") :]
-
-        params = {
-            "model": request.model,
-            **input_dict,
-            "stream": bool(request.stream),
-            **self._build_options(request.sampling_params, request.response_format, request.logprobs),
-        }
-        logger.debug(f"params to fireworks: {params}")
-        return params


@@ -10,6 +10,6 @@ from .config import GeminiConfig
 async def get_adapter_impl(config: GeminiConfig, _deps):
     from .gemini import GeminiInferenceAdapter
 
-    impl = GeminiInferenceAdapter(config)
+    impl = GeminiInferenceAdapter(config=config)
     await impl.initialize()
     return impl


@@ -8,6 +8,7 @@ from typing import Any
 
 from pydantic import BaseModel, Field
 
+from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
 from llama_stack.schema_utils import json_schema_type
 
 
@@ -19,7 +20,7 @@ class GeminiProviderDataValidator(BaseModel):
 
 
 @json_schema_type
-class GeminiConfig(BaseModel):
+class GeminiConfig(RemoteInferenceProviderConfig):
     api_key: str | None = Field(
         default=None,
         description="API key for Gemini models",


@@ -4,33 +4,21 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
 from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
 
 from .config import GeminiConfig
 
 
-class GeminiInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin):
-    embedding_model_metadata = {
+class GeminiInferenceAdapter(OpenAIMixin):
+    config: GeminiConfig
+
+    provider_data_api_key_field: str = "gemini_api_key"
+    embedding_model_metadata: dict[str, dict[str, int]] = {
         "text-embedding-004": {"embedding_dimension": 768, "context_length": 2048},
     }
 
-    def __init__(self, config: GeminiConfig) -> None:
-        LiteLLMOpenAIMixin.__init__(
-            self,
-            litellm_provider_name="gemini",
-            api_key_from_config=config.api_key,
-            provider_data_api_key_field="gemini_api_key",
-        )
-        self.config = config
-
-    get_api_key = LiteLLMOpenAIMixin.get_api_key
+    def get_api_key(self) -> str:
+        return self.config.api_key or ""
 
     def get_base_url(self):
         return "https://generativelanguage.googleapis.com/v1beta/openai/"
-
-    async def initialize(self) -> None:
-        await super().initialize()
-
-    async def shutdown(self) -> None:
-        await super().shutdown()


@@ -11,5 +11,5 @@ async def get_adapter_impl(config: GroqConfig, _deps):
     # import dynamically so the import is used only when it is needed
     from .groq import GroqInferenceAdapter
 
-    adapter = GroqInferenceAdapter(config)
+    adapter = GroqInferenceAdapter(config=config)
     return adapter

Some files were not shown because too many files have changed in this diff.