Merge branch 'main' into dead_code_removal

This commit is contained in:
Omar Abdelwahab 2025-10-06 13:21:36 -07:00 committed by GitHub
commit 9886520b40
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
927 changed files with 171924 additions and 102933 deletions

2
.github/CODEOWNERS vendored
View file

@ -2,4 +2,4 @@
# These owners will be the default owners for everything in
# the repo. Unless a later match takes precedence,
* @ashwinb @yanxi0830 @hardikjshah @raghotham @ehhuang @terrytangyuan @leseb @bbrowning @reluctantfuturist @mattf @slekkala1
* @ashwinb @yanxi0830 @hardikjshah @raghotham @ehhuang @terrytangyuan @leseb @bbrowning @reluctantfuturist @mattf @slekkala1 @franciscojavierarceo

1
.github/TRIAGERS.md vendored
View file

@ -1,2 +1 @@
# This file documents Triage members in the Llama Stack community
@franciscojavierarceo

View file

@ -12,6 +12,7 @@ Llama Stack uses GitHub Actions for Continuous Integration (CI). Below is a tabl
| Integration Tests (Replay) | [integration-tests.yml](integration-tests.yml) | Run the integration test suites from tests/integration in replay mode |
| Vector IO Integration Tests | [integration-vector-io-tests.yml](integration-vector-io-tests.yml) | Run the integration test suite with various VectorIO providers |
| Pre-commit | [pre-commit.yml](pre-commit.yml) | Run pre-commit checks |
| Pre-commit Bot | [precommit-trigger.yml](precommit-trigger.yml) | Pre-commit bot for PR |
| Test Llama Stack Build | [providers-build.yml](providers-build.yml) | Test llama stack build |
| Python Package Build Test | [python-build-test.yml](python-build-test.yml) | Test building the llama-stack PyPI project |
| Integration Tests (Record) | [record-integration-tests.yml](record-integration-tests.yml) | Run the integration test suite from tests/integration |

View file

@ -84,6 +84,8 @@ jobs:
yq eval '.server.auth.provider_config.jwks.token = "${{ env.TOKEN }}"' -i $run_dir/run.yaml
cat $run_dir/run.yaml
# avoid line breaks in the server log, especially because we grep it below.
export COLUMNS=1984
nohup uv run llama stack run $run_dir/run.yaml --image-type venv > server.log 2>&1 &
- name: Wait for Llama Stack server to be ready

View file

@ -42,18 +42,27 @@ jobs:
run-replay-mode-tests:
runs-on: ubuntu-latest
name: ${{ format('Integration Tests ({0}, {1}, {2}, client={3}, {4})', matrix.client-type, matrix.setup, matrix.python-version, matrix.client-version, matrix.suite) }}
name: ${{ format('Integration Tests ({0}, {1}, {2}, client={3}, {4})', matrix.client-type, matrix.config.setup, matrix.python-version, matrix.client-version, matrix.config.suite) }}
strategy:
fail-fast: false
matrix:
client-type: [library, server]
# Use vllm on weekly schedule, otherwise use test-setup input (defaults to ollama)
setup: ${{ (github.event.schedule == '1 0 * * 0') && fromJSON('["vllm"]') || fromJSON(format('["{0}"]', github.event.inputs.test-setup || 'ollama')) }}
# Use Python 3.13 only on nightly schedule (daily latest client test), otherwise use 3.12
python-version: ${{ github.event.schedule == '0 0 * * *' && fromJSON('["3.12", "3.13"]') || fromJSON('["3.12"]') }}
client-version: ${{ (github.event.schedule == '0 0 * * *' || github.event.inputs.test-all-client-versions == 'true') && fromJSON('["published", "latest"]') || fromJSON('["latest"]') }}
suite: [base, vision]
# Define (setup, suite) pairs - they are always matched and cannot be independent
# Weekly schedule (Sun 1 AM): vllm+base
# Input test-setup=ollama-vision: ollama-vision+vision
# Default (including test-setup=ollama): both ollama+base and ollama-vision+vision
config: >-
${{
github.event.schedule == '1 0 * * 0'
&& fromJSON('[{"setup": "vllm", "suite": "base"}]')
|| github.event.inputs.test-setup == 'ollama-vision'
&& fromJSON('[{"setup": "ollama-vision", "suite": "vision"}]')
|| fromJSON('[{"setup": "ollama", "suite": "base"}, {"setup": "ollama-vision", "suite": "vision"}]')
}}
steps:
- name: Checkout repository
@ -64,14 +73,14 @@ jobs:
with:
python-version: ${{ matrix.python-version }}
client-version: ${{ matrix.client-version }}
setup: ${{ matrix.setup }}
suite: ${{ matrix.suite }}
setup: ${{ matrix.config.setup }}
suite: ${{ matrix.config.suite }}
inference-mode: 'replay'
- name: Run tests
uses: ./.github/actions/run-and-record-tests
with:
stack-config: ${{ matrix.client-type == 'library' && 'ci-tests' || 'server:ci-tests' }}
setup: ${{ matrix.setup }}
setup: ${{ matrix.config.setup }}
inference-mode: 'replay'
suite: ${{ matrix.suite }}
suite: ${{ matrix.config.suite }}

227
.github/workflows/precommit-trigger.yml vendored Normal file
View file

@ -0,0 +1,227 @@
name: Pre-commit Bot
run-name: Pre-commit bot for PR #${{ github.event.issue.number }}
on:
issue_comment:
types: [created]
jobs:
pre-commit:
# Only run on pull request comments
if: github.event.issue.pull_request && contains(github.event.comment.body, '@github-actions run precommit')
runs-on: ubuntu-latest
permissions:
contents: write
pull-requests: write
steps:
- name: Check comment author and get PR details
id: check_author
uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8.0.0
with:
github-token: ${{ secrets.GITHUB_TOKEN }}
script: |
// Get PR details
const pr = await github.rest.pulls.get({
owner: context.repo.owner,
repo: context.repo.repo,
pull_number: context.issue.number
});
// Check if commenter has write access or is the PR author
const commenter = context.payload.comment.user.login;
const prAuthor = pr.data.user.login;
let hasPermission = false;
// Check if commenter is PR author
if (commenter === prAuthor) {
hasPermission = true;
console.log(`Comment author ${commenter} is the PR author`);
} else {
// Check if commenter has write/admin access
try {
const permission = await github.rest.repos.getCollaboratorPermissionLevel({
owner: context.repo.owner,
repo: context.repo.repo,
username: commenter
});
const level = permission.data.permission;
hasPermission = ['write', 'admin', 'maintain'].includes(level);
console.log(`Comment author ${commenter} has permission: ${level}`);
} catch (error) {
console.log(`Could not check permissions for ${commenter}: ${error.message}`);
}
}
if (!hasPermission) {
await github.rest.issues.createComment({
owner: context.repo.owner,
repo: context.repo.repo,
issue_number: context.issue.number,
body: `❌ @${commenter} You don't have permission to trigger pre-commit. Only PR authors or repository collaborators can run this command.`
});
core.setFailed(`User ${commenter} does not have permission`);
return;
}
// Save PR info for later steps
core.setOutput('pr_number', context.issue.number);
core.setOutput('pr_head_ref', pr.data.head.ref);
core.setOutput('pr_head_sha', pr.data.head.sha);
core.setOutput('pr_head_repo', pr.data.head.repo.full_name);
core.setOutput('pr_base_ref', pr.data.base.ref);
core.setOutput('is_fork', pr.data.head.repo.full_name !== context.payload.repository.full_name);
core.setOutput('authorized', 'true');
- name: React to comment
if: steps.check_author.outputs.authorized == 'true'
uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8.0.0
with:
github-token: ${{ secrets.GITHUB_TOKEN }}
script: |
await github.rest.reactions.createForIssueComment({
owner: context.repo.owner,
repo: context.repo.repo,
comment_id: context.payload.comment.id,
content: 'rocket'
});
- name: Comment starting
if: steps.check_author.outputs.authorized == 'true'
uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8.0.0
with:
github-token: ${{ secrets.GITHUB_TOKEN }}
script: |
await github.rest.issues.createComment({
owner: context.repo.owner,
repo: context.repo.repo,
issue_number: ${{ steps.check_author.outputs.pr_number }},
body: `⏳ Running pre-commit hooks on PR #${{ steps.check_author.outputs.pr_number }}...`
});
- name: Checkout PR branch (same-repo)
if: steps.check_author.outputs.authorized == 'true' && steps.check_author.outputs.is_fork == 'false'
uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
with:
ref: ${{ steps.check_author.outputs.pr_head_ref }}
fetch-depth: 0
token: ${{ secrets.GITHUB_TOKEN }}
- name: Checkout PR branch (fork)
if: steps.check_author.outputs.authorized == 'true' && steps.check_author.outputs.is_fork == 'true'
uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
with:
repository: ${{ steps.check_author.outputs.pr_head_repo }}
ref: ${{ steps.check_author.outputs.pr_head_ref }}
fetch-depth: 0
token: ${{ secrets.GITHUB_TOKEN }}
- name: Verify checkout
if: steps.check_author.outputs.authorized == 'true'
run: |
echo "Current SHA: $(git rev-parse HEAD)"
echo "Expected SHA: ${{ steps.check_author.outputs.pr_head_sha }}"
if [[ "$(git rev-parse HEAD)" != "${{ steps.check_author.outputs.pr_head_sha }}" ]]; then
echo "::error::Checked out SHA does not match expected SHA"
exit 1
fi
- name: Set up Python
if: steps.check_author.outputs.authorized == 'true'
uses: actions/setup-python@e797f83bcb11b83ae66e0230d6156d7c80228e7c # v6.0.0
with:
python-version: '3.12'
cache: pip
cache-dependency-path: |
**/requirements*.txt
.pre-commit-config.yaml
- name: Set up Node.js
if: steps.check_author.outputs.authorized == 'true'
uses: actions/setup-node@a0853c24544627f65ddf259abe73b1d18a591444 # v5.0.0
with:
node-version: '20'
cache: 'npm'
cache-dependency-path: 'llama_stack/ui/'
- name: Install npm dependencies
if: steps.check_author.outputs.authorized == 'true'
run: npm ci
working-directory: llama_stack/ui
- name: Run pre-commit
if: steps.check_author.outputs.authorized == 'true'
id: precommit
uses: pre-commit/action@2c7b3805fd2a0fd8c1884dcaebf91fc102a13ecd # v3.0.1
continue-on-error: true
env:
SKIP: no-commit-to-branch
RUFF_OUTPUT_FORMAT: github
- name: Check for changes
if: steps.check_author.outputs.authorized == 'true'
id: changes
run: |
if ! git diff --exit-code || [ -n "$(git ls-files --others --exclude-standard)" ]; then
echo "has_changes=true" >> $GITHUB_OUTPUT
echo "Changes detected after pre-commit"
else
echo "has_changes=false" >> $GITHUB_OUTPUT
echo "No changes after pre-commit"
fi
- name: Commit and push changes
if: steps.check_author.outputs.authorized == 'true' && steps.changes.outputs.has_changes == 'true'
run: |
git config --local user.email "github-actions[bot]@users.noreply.github.com"
git config --local user.name "github-actions[bot]"
git add -A
git commit -m "style: apply pre-commit fixes
🤖 Applied by @github-actions bot via pre-commit workflow"
# Push changes
git push origin HEAD:${{ steps.check_author.outputs.pr_head_ref }}
- name: Comment success with changes
if: steps.check_author.outputs.authorized == 'true' && steps.changes.outputs.has_changes == 'true'
uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8.0.0
with:
github-token: ${{ secrets.GITHUB_TOKEN }}
script: |
await github.rest.issues.createComment({
owner: context.repo.owner,
repo: context.repo.repo,
issue_number: ${{ steps.check_author.outputs.pr_number }},
body: `✅ Pre-commit hooks completed successfully!\n\n🔧 Changes have been committed and pushed to the PR branch.`
});
- name: Comment success without changes
if: steps.check_author.outputs.authorized == 'true' && steps.changes.outputs.has_changes == 'false' && steps.precommit.outcome == 'success'
uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8.0.0
with:
github-token: ${{ secrets.GITHUB_TOKEN }}
script: |
await github.rest.issues.createComment({
owner: context.repo.owner,
repo: context.repo.repo,
issue_number: ${{ steps.check_author.outputs.pr_number }},
body: `✅ Pre-commit hooks passed!\n\n✨ No changes needed - your code is already formatted correctly.`
});
- name: Comment failure
if: failure()
uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8.0.0
with:
github-token: ${{ secrets.GITHUB_TOKEN }}
script: |
await github.rest.issues.createComment({
owner: context.repo.owner,
repo: context.repo.repo,
issue_number: ${{ steps.check_author.outputs.pr_number }},
body: `❌ Pre-commit workflow failed!\n\nPlease check the [workflow logs](https://github.com/${context.repo.owner}/${context.repo.repo}/actions/runs/${context.runId}) for details.`
});

View file

@ -112,7 +112,7 @@ jobs:
fi
entrypoint=$(docker inspect --format '{{ .Config.Entrypoint }}' $IMAGE_ID)
echo "Entrypoint: $entrypoint"
if [ "$entrypoint" != "[python -m llama_stack.core.server.server /app/run.yaml]" ]; then
if [ "$entrypoint" != "[llama stack run /app/run.yaml]" ]; then
echo "Entrypoint is not correct"
exit 1
fi
@ -150,7 +150,7 @@ jobs:
fi
entrypoint=$(docker inspect --format '{{ .Config.Entrypoint }}' $IMAGE_ID)
echo "Entrypoint: $entrypoint"
if [ "$entrypoint" != "[python -m llama_stack.core.server.server /app/run.yaml]" ]; then
if [ "$entrypoint" != "[llama stack run /app/run.yaml]" ]; then
echo "Entrypoint is not correct"
exit 1
fi

View file

@ -24,7 +24,7 @@ jobs:
uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
- name: Install uv
uses: astral-sh/setup-uv@b75a909f75acd358c2196fb9a5f1299a9a8868a4 # v6.7.0
uses: astral-sh/setup-uv@d0cc045d04ccac9d8b7881df0226f9e82c39688e # v6.8.0
with:
python-version: ${{ matrix.python-version }}
activate-environment: true

View file

@ -7,7 +7,7 @@
[![Unit Tests](https://github.com/meta-llama/llama-stack/actions/workflows/unit-tests.yml/badge.svg?branch=main)](https://github.com/meta-llama/llama-stack/actions/workflows/unit-tests.yml?query=branch%3Amain)
[![Integration Tests](https://github.com/meta-llama/llama-stack/actions/workflows/integration-tests.yml/badge.svg?branch=main)](https://github.com/meta-llama/llama-stack/actions/workflows/integration-tests.yml?query=branch%3Amain)
[**Quick Start**](https://llamastack.github.io/latest/getting_started/index.html) | [**Documentation**](https://llamastack.github.io/latest/index.html) | [**Colab Notebook**](./docs/getting_started.ipynb) | [**Discord**](https://discord.gg/llama-stack)
[**Quick Start**](https://llamastack.github.io/docs/getting_started/quickstart) | [**Documentation**](https://llamastack.github.io/docs) | [**Colab Notebook**](./docs/getting_started.ipynb) | [**Discord**](https://discord.gg/llama-stack)
### ✨🎉 Llama 4 Support 🎉✨

View file

@ -187,21 +187,21 @@ Configure telemetry behavior using environment variables:
- **`OTEL_SERVICE_NAME`**: Service name for telemetry (default: empty string)
- **`TELEMETRY_SINKS`**: Comma-separated list of sinks (default: `console,sqlite`)
## Visualization with Jaeger
### Quick Setup: Complete Telemetry Stack
The `otel_trace` sink works with any service compatible with the OpenTelemetry collector. Traces and metrics use separate endpoints but can share the same collector.
### Starting Jaeger
Start a Jaeger instance with OTLP HTTP endpoint at 4318 and the Jaeger UI at 16686:
Use the automated setup script to launch the complete telemetry stack (Jaeger, OpenTelemetry Collector, Prometheus, and Grafana):
```bash
docker run --pull always --rm --name jaeger \
-p 16686:16686 -p 4318:4318 \
jaegertracing/jaeger:2.1.0
./scripts/telemetry/setup_telemetry.sh
```
Once running, you can visualize traces by navigating to [http://localhost:16686/](http://localhost:16686/).
This sets up:
- **Jaeger UI**: http://localhost:16686 (traces visualization)
- **Prometheus**: http://localhost:9090 (metrics)
- **Grafana**: http://localhost:3000 (dashboards with auto-configured data sources)
- **OTEL Collector**: http://localhost:4318 (OTLP endpoint)
Once running, you can visualize traces by navigating to [Grafana](http://localhost:3000/) and login with login `admin` and password `admin`.
## Querying Metrics

View file

@ -152,7 +152,6 @@ __all__ = ["WeatherAPI", "available_providers"]
from typing import Protocol
from llama_stack.providers.datatypes import (
AdapterSpec,
Api,
ProviderSpec,
RemoteProviderSpec,
@ -166,12 +165,10 @@ def available_providers() -> list[ProviderSpec]:
api=Api.weather,
provider_type="remote::kaze",
config_class="llama_stack_provider_kaze.KazeProviderConfig",
adapter=AdapterSpec(
adapter_type="kaze",
module="llama_stack_provider_kaze",
pip_packages=["llama_stack_provider_kaze"],
config_class="llama_stack_provider_kaze.KazeProviderConfig",
),
adapter_type="kaze",
module="llama_stack_provider_kaze",
pip_packages=["llama_stack_provider_kaze"],
config_class="llama_stack_provider_kaze.KazeProviderConfig",
),
]
@ -325,11 +322,10 @@ class WeatherKazeAdapter(WeatherProvider):
```yaml
# ~/.llama/providers.d/remote/weather/kaze.yaml
adapter:
adapter_type: kaze
pip_packages: ["llama_stack_provider_kaze"]
config_class: llama_stack_provider_kaze.config.KazeProviderConfig
module: llama_stack_provider_kaze
adapter_type: kaze
pip_packages: ["llama_stack_provider_kaze"]
config_class: llama_stack_provider_kaze.config.KazeProviderConfig
module: llama_stack_provider_kaze
optional_api_dependencies: []
```
@ -361,7 +357,7 @@ server:
8. Run the server:
```bash
python -m llama_stack.core.server.server --yaml-config ~/.llama/run-byoa.yaml
llama stack run ~/.llama/run-byoa.yaml
```
9. Test the API:

View file

@ -170,7 +170,7 @@ spec:
- name: llama-stack
image: localhost/llama-stack-run-k8s:latest
imagePullPolicy: IfNotPresent
command: ["python", "-m", "llama_stack.core.server.server", "--config", "/app/config.yaml"]
command: ["llama", "stack", "run", "/app/config.yaml"]
ports:
- containerPort: 5000
volumeMounts:

View file

@ -509,16 +509,16 @@ server:
provider_config:
type: "github_token"
github_api_base_url: "https://api.github.com"
access_policy:
- permit:
principal: user-1
actions: [create, read, delete]
description: user-1 has full access to all resources
- permit:
principal: user-2
actions: [read]
resource: model::model-1
description: user-2 has read access to model-1 only
access_policy:
- permit:
principal: user-1
actions: [create, read, delete]
description: user-1 has full access to all resources
- permit:
principal: user-2
actions: [read]
resource: model::model-1
description: user-2 has read access to model-1 only
```
Similarly, the following restricts access to particular kubernetes

View file

@ -52,7 +52,7 @@ spec:
value: "${SAFETY_MODEL}"
- name: TAVILY_SEARCH_API_KEY
value: "${TAVILY_SEARCH_API_KEY}"
command: ["python", "-m", "llama_stack.core.server.server", "/etc/config/stack_run_config.yaml", "--port", "8321"]
command: ["llama", "stack", "run", "/etc/config/stack_run_config.yaml", "--port", "8321"]
ports:
- containerPort: 8321
volumeMounts:

View file

@ -11,38 +11,6 @@ an example entry in your build.yaml should look like:
module: ramalama_stack
```
Additionally you can configure the `external_providers_dir` in your Llama Stack configuration. This method is in the process of being deprecated in favor of the `module` method. If using this method, the external provider directory should contain your external provider specifications:
```yaml
external_providers_dir: ~/.llama/providers.d/
```
## Directory Structure
The external providers directory should follow this structure:
```
providers.d/
remote/
inference/
custom_ollama.yaml
vllm.yaml
vector_io/
qdrant.yaml
safety/
llama-guard.yaml
inline/
inference/
custom_ollama.yaml
vllm.yaml
vector_io/
qdrant.yaml
safety/
llama-guard.yaml
```
Each YAML file in these directories defines a provider specification for that particular API.
## Provider Types
Llama Stack supports two types of external providers:
@ -50,30 +18,37 @@ Llama Stack supports two types of external providers:
1. **Remote Providers**: Providers that communicate with external services (e.g., cloud APIs)
2. **Inline Providers**: Providers that run locally within the Llama Stack process
### Provider Specification (Common between inline and remote providers)
- `provider_type`: The type of the provider to be installed (remote or inline). eg. `remote::ollama`
- `api`: The API for this provider, eg. `inference`
- `config_class`: The full path to the configuration class
- `module`: The Python module containing the provider implementation
- `optional_api_dependencies`: List of optional Llama Stack APIs that this provider can use
- `api_dependencies`: List of Llama Stack APIs that this provider depends on
- `provider_data_validator`: Optional validator for provider data.
- `pip_packages`: List of Python packages required by the provider
### Remote Provider Specification
Remote providers are used when you need to communicate with external services. Here's an example for a custom Ollama provider:
```yaml
adapter:
adapter_type: custom_ollama
pip_packages:
- ollama
- aiohttp
config_class: llama_stack_ollama_provider.config.OllamaImplConfig
module: llama_stack_ollama_provider
adapter_type: custom_ollama
provider_type: "remote::ollama"
pip_packages:
- ollama
- aiohttp
config_class: llama_stack_ollama_provider.config.OllamaImplConfig
module: llama_stack_ollama_provider
api_dependencies: []
optional_api_dependencies: []
```
#### Adapter Configuration
#### Remote Provider Configuration
The `adapter` section defines how to load and configure the provider:
- `adapter_type`: A unique identifier for this adapter
- `pip_packages`: List of Python packages required by the provider
- `config_class`: The full path to the configuration class
- `module`: The Python module containing the provider implementation
- `adapter_type`: A unique identifier for this adapter, eg. `ollama`
### Inline Provider Specification
@ -81,6 +56,7 @@ Inline providers run locally within the Llama Stack process. Here's an example f
```yaml
module: llama_stack_vector_provider
provider_type: inline::llama_stack_vector_provider
config_class: llama_stack_vector_provider.config.VectorStoreConfig
pip_packages:
- faiss-cpu
@ -95,12 +71,6 @@ container_image: custom-vector-store:latest # optional
#### Inline Provider Fields
- `module`: The Python module containing the provider implementation
- `config_class`: The full path to the configuration class
- `pip_packages`: List of Python packages required by the provider
- `api_dependencies`: List of Llama Stack APIs that this provider depends on
- `optional_api_dependencies`: List of optional Llama Stack APIs that this provider can use
- `provider_data_validator`: Optional validator for provider data
- `container_image`: Optional container image to use instead of pip packages
## Required Fields
@ -113,20 +83,17 @@ All providers must contain a `get_provider_spec` function in their `provider` mo
from llama_stack.providers.datatypes import (
ProviderSpec,
Api,
AdapterSpec,
remote_provider_spec,
RemoteProviderSpec,
)
def get_provider_spec() -> ProviderSpec:
return remote_provider_spec(
return RemoteProviderSpec(
api=Api.inference,
adapter=AdapterSpec(
adapter_type="ramalama",
pip_packages=["ramalama>=0.8.5", "pymilvus"],
config_class="ramalama_stack.config.RamalamaImplConfig",
module="ramalama_stack",
),
adapter_type="ramalama",
pip_packages=["ramalama>=0.8.5", "pymilvus"],
config_class="ramalama_stack.config.RamalamaImplConfig",
module="ramalama_stack",
)
```
@ -197,18 +164,16 @@ information. Execute the test for the Provider type you are developing.
If your external provider isn't being loaded:
1. Check that `module` points to a published pip package with a top level `provider` module including `get_provider_spec`.
1. Check that the `external_providers_dir` path is correct and accessible.
2. Verify that the YAML files are properly formatted.
3. Ensure all required Python packages are installed.
4. Check the Llama Stack server logs for any error messages - turn on debug logging to get more
information using `LLAMA_STACK_LOGGING=all=debug`.
5. Verify that the provider package is installed in your Python environment if using `external_providers_dir`.
## Examples
### Example using `external_providers_dir`: Custom Ollama Provider
### How to create an external provider module
Here's a complete example of creating and using a custom Ollama provider:
If you are creating a new external provider called `llama-stack-provider-ollama` here is how you would set up the package properly:
1. First, create the provider package:
@ -230,33 +195,28 @@ requires-python = ">=3.12"
dependencies = ["llama-stack", "pydantic", "ollama", "aiohttp"]
```
3. Create the provider specification:
```yaml
# ~/.llama/providers.d/remote/inference/custom_ollama.yaml
adapter:
adapter_type: custom_ollama
pip_packages: ["ollama", "aiohttp"]
config_class: llama_stack_provider_ollama.config.OllamaImplConfig
module: llama_stack_provider_ollama
api_dependencies: []
optional_api_dependencies: []
```
4. Install the provider:
3. Install the provider:
```bash
uv pip install -e .
```
5. Configure Llama Stack to use external providers:
4. Edit `provider.py`
```yaml
external_providers_dir: ~/.llama/providers.d/
provider.py must be updated to contain `get_provider_spec`. This is used by llama stack to install the provider.
```python
def get_provider_spec() -> ProviderSpec:
return RemoteProviderSpec(
api=Api.inference,
adapter_type="llama-stack-provider-ollama",
pip_packages=["ollama", "aiohttp"],
config_class="llama_stack_provider_ollama.config.OllamaImplConfig",
module="llama_stack_provider_ollama",
)
```
The provider will now be available in Llama Stack with the type `remote::custom_ollama`.
5. Implement the provider as outlined above with `get_provider_impl` or `get_adapter_impl`, etc.
### Example using `module`: ramalama-stack
@ -275,7 +235,6 @@ distribution_spec:
module: ramalama_stack==0.3.0a0
image_type: venv
image_name: null
external_providers_dir: null
additional_pip_packages:
- aiosqlite
- sqlalchemy[asyncio]

View file

@ -1,4 +1,7 @@
---
description: "Files
This API is used to upload documents that can be used with other Llama Stack APIs."
sidebar_label: Files
title: Files
---
@ -7,4 +10,8 @@ title: Files
## Overview
Files
This API is used to upload documents that can be used with other Llama Stack APIs.
This section contains documentation for all available providers for the **files** API.

View file

@ -1,5 +1,7 @@
---
description: "Llama Stack Inference API for generating completions, chat completions, and embeddings.
description: "Inference
Llama Stack Inference API for generating completions, chat completions, and embeddings.
This API provides the raw interface to the underlying models. Two kinds of models are supported:
- LLM models: these models generate \"raw\" and \"chat\" (conversational) completions.
@ -12,7 +14,9 @@ title: Inference
## Overview
Llama Stack Inference API for generating completions, chat completions, and embeddings.
Inference
Llama Stack Inference API for generating completions, chat completions, and embeddings.
This API provides the raw interface to the underlying models. Two kinds of models are supported:
- LLM models: these models generate "raw" and "chat" (conversational) completions.

View file

@ -14,6 +14,7 @@ Anthropic inference provider for accessing Claude models and Anthropic's AI serv
| Field | Type | Required | Default | Description |
|-------|------|----------|---------|-------------|
| `allowed_models` | `list[str \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
| `api_key` | `str \| None` | No | | API key for Anthropic models |
## Sample Configuration

View file

@ -21,6 +21,7 @@ https://learn.microsoft.com/en-us/azure/ai-foundry/openai/overview
| Field | Type | Required | Default | Description |
|-------|------|----------|---------|-------------|
| `allowed_models` | `list[str \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
| `api_key` | `<class 'pydantic.types.SecretStr'>` | No | | Azure API key for Azure |
| `api_base` | `<class 'pydantic.networks.HttpUrl'>` | No | | Azure API base for Azure (e.g., https://your-resource-name.openai.azure.com) |
| `api_version` | `str \| None` | No | | Azure API version for Azure (e.g., 2024-12-01-preview) |

View file

@ -14,6 +14,7 @@ AWS Bedrock inference provider for accessing various AI models through AWS's man
| Field | Type | Required | Default | Description |
|-------|------|----------|---------|-------------|
| `allowed_models` | `list[str \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
| `aws_access_key_id` | `str \| None` | No | | The AWS access key to use. Default use environment variable: AWS_ACCESS_KEY_ID |
| `aws_secret_access_key` | `str \| None` | No | | The AWS secret access key to use. Default use environment variable: AWS_SECRET_ACCESS_KEY |
| `aws_session_token` | `str \| None` | No | | The AWS session token to use. Default use environment variable: AWS_SESSION_TOKEN |

View file

@ -14,6 +14,7 @@ Cerebras inference provider for running models on Cerebras Cloud platform.
| Field | Type | Required | Default | Description |
|-------|------|----------|---------|-------------|
| `allowed_models` | `list[str \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
| `base_url` | `<class 'str'>` | No | https://api.cerebras.ai | Base URL for the Cerebras API |
| `api_key` | `<class 'pydantic.types.SecretStr'>` | No | | Cerebras API Key |

View file

@ -14,7 +14,8 @@ Databricks inference provider for running models on Databricks' unified analytic
| Field | Type | Required | Default | Description |
|-------|------|----------|---------|-------------|
| `url` | `<class 'str'>` | No | | The URL for the Databricks model serving endpoint |
| `allowed_models` | `list[str \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
| `url` | `str \| None` | No | | The URL for the Databricks model serving endpoint |
| `api_token` | `<class 'pydantic.types.SecretStr'>` | No | | The Databricks API token |
## Sample Configuration

View file

@ -14,6 +14,7 @@ Google Gemini inference provider for accessing Gemini models and Google's AI ser
| Field | Type | Required | Default | Description |
|-------|------|----------|---------|-------------|
| `allowed_models` | `list[str \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
| `api_key` | `str \| None` | No | | API key for Gemini models |
## Sample Configuration

View file

@ -14,6 +14,7 @@ Groq inference provider for ultra-fast inference using Groq's LPU technology.
| Field | Type | Required | Default | Description |
|-------|------|----------|---------|-------------|
| `allowed_models` | `list[str \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
| `api_key` | `str \| None` | No | | The Groq API key |
| `url` | `<class 'str'>` | No | https://api.groq.com | The URL for the Groq AI server |

View file

@ -14,6 +14,7 @@ Llama OpenAI-compatible provider for using Llama models with OpenAI API format.
| Field | Type | Required | Default | Description |
|-------|------|----------|---------|-------------|
| `allowed_models` | `list[str \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
| `api_key` | `str \| None` | No | | The Llama API key |
| `openai_compat_api_base` | `<class 'str'>` | No | https://api.llama.com/compat/v1/ | The URL for the Llama API server |

View file

@ -14,6 +14,7 @@ NVIDIA inference provider for accessing NVIDIA NIM models and AI services.
| Field | Type | Required | Default | Description |
|-------|------|----------|---------|-------------|
| `allowed_models` | `list[str \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
| `url` | `<class 'str'>` | No | https://integrate.api.nvidia.com | A base url for accessing the NVIDIA NIM |
| `api_key` | `pydantic.types.SecretStr \| None` | No | | The NVIDIA API key, only needed of using the hosted service |
| `timeout` | `<class 'int'>` | No | 60 | Timeout for the HTTP requests |

View file

@ -14,6 +14,7 @@ Ollama inference provider for running local models through the Ollama runtime.
| Field | Type | Required | Default | Description |
|-------|------|----------|---------|-------------|
| `allowed_models` | `list[str \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
| `url` | `<class 'str'>` | No | http://localhost:11434 | |
| `refresh_models` | `<class 'bool'>` | No | False | Whether to refresh models periodically |

View file

@ -14,6 +14,7 @@ OpenAI inference provider for accessing GPT models and other OpenAI services.
| Field | Type | Required | Default | Description |
|-------|------|----------|---------|-------------|
| `allowed_models` | `list[str \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
| `api_key` | `str \| None` | No | | API key for OpenAI models |
| `base_url` | `<class 'str'>` | No | https://api.openai.com/v1 | Base URL for OpenAI API |

View file

@ -14,6 +14,7 @@ Passthrough inference provider for connecting to any external inference service
| Field | Type | Required | Default | Description |
|-------|------|----------|---------|-------------|
| `allowed_models` | `list[str \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
| `url` | `<class 'str'>` | No | | The URL for the passthrough endpoint |
| `api_key` | `pydantic.types.SecretStr \| None` | No | | API Key for the passthrouth endpoint |

View file

@ -14,6 +14,7 @@ RunPod inference provider for running models on RunPod's cloud GPU platform.
| Field | Type | Required | Default | Description |
|-------|------|----------|---------|-------------|
| `allowed_models` | `list[str \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
| `url` | `str \| None` | No | | The URL for the Runpod model serving endpoint |
| `api_token` | `str \| None` | No | | The API token |

View file

@ -14,6 +14,7 @@ SambaNova inference provider for running models on SambaNova's dataflow architec
| Field | Type | Required | Default | Description |
|-------|------|----------|---------|-------------|
| `allowed_models` | `list[str \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
| `url` | `<class 'str'>` | No | https://api.sambanova.ai/v1 | The URL for the SambaNova AI server |
| `api_key` | `pydantic.types.SecretStr \| None` | No | | The SambaNova cloud API Key |

View file

@ -14,6 +14,7 @@ Text Generation Inference (TGI) provider for HuggingFace model serving.
| Field | Type | Required | Default | Description |
|-------|------|----------|---------|-------------|
| `allowed_models` | `list[str \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
| `url` | `<class 'str'>` | No | | The URL for the TGI serving endpoint |
## Sample Configuration

View file

@ -53,6 +53,7 @@ Available Models:
| Field | Type | Required | Default | Description |
|-------|------|----------|---------|-------------|
| `allowed_models` | `list[str \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
| `project` | `<class 'str'>` | No | | Google Cloud project ID for Vertex AI |
| `location` | `<class 'str'>` | No | us-central1 | Google Cloud location for Vertex AI |

View file

@ -14,6 +14,7 @@ Remote vLLM inference provider for connecting to vLLM servers.
| Field | Type | Required | Default | Description |
|-------|------|----------|---------|-------------|
| `allowed_models` | `list[str \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
| `url` | `str \| None` | No | | The URL for the vLLM model serving endpoint |
| `max_tokens` | `<class 'int'>` | No | 4096 | Maximum number of tokens to generate. |
| `api_token` | `str \| None` | No | fake | The API token |

View file

@ -14,6 +14,7 @@ IBM WatsonX inference provider for accessing AI models on IBM's WatsonX platform
| Field | Type | Required | Default | Description |
|-------|------|----------|---------|-------------|
| `allowed_models` | `list[str \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
| `url` | `<class 'str'>` | No | https://us-south.ml.cloud.ibm.com | A base url for accessing the watsonx.ai |
| `api_key` | `pydantic.types.SecretStr \| None` | No | | The watsonx API key |
| `project_id` | `str \| None` | No | | The Project ID key |

View file

@ -1,4 +1,7 @@
---
description: "Safety
OpenAI-compatible Moderations API."
sidebar_label: Safety
title: Safety
---
@ -7,4 +10,8 @@ title: Safety
## Overview
Safety
OpenAI-compatible Moderations API.
This section contains documentation for all available providers for the **safety** API.

View file

@ -14,6 +14,7 @@ AWS Bedrock safety provider for content moderation using AWS's safety services.
| Field | Type | Required | Default | Description |
|-------|------|----------|---------|-------------|
| `allowed_models` | `list[str \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
| `aws_access_key_id` | `str \| None` | No | | The AWS access key to use. Default use environment variable: AWS_ACCESS_KEY_ID |
| `aws_secret_access_key` | `str \| None` | No | | The AWS secret access key to use. Default use environment variable: AWS_SECRET_ACCESS_KEY |
| `aws_session_token` | `str \| None` | No | | The AWS session token to use. Default use environment variable: AWS_SESSION_TOKEN |

View file

@ -50,6 +50,7 @@ from .specification import (
Document,
Example,
ExampleRef,
ExtraBodyParameter,
MediaType,
Operation,
Parameter,
@ -677,6 +678,27 @@ class Generator:
# parameters passed anywhere
parameters = path_parameters + query_parameters
# Build extra body parameters documentation
extra_body_parameters = []
for param_name, param_type, description in op.extra_body_params:
if is_type_optional(param_type):
inner_type: type = unwrap_optional_type(param_type)
required = False
else:
inner_type = param_type
required = True
# Use description from ExtraBodyField if available, otherwise from docstring
param_description = description or doc_params.get(param_name)
extra_body_param = ExtraBodyParameter(
name=param_name,
schema=self.schema_builder.classdef_to_ref(inner_type),
description=param_description,
required=required,
)
extra_body_parameters.append(extra_body_param)
webmethod = getattr(op.func_ref, "__webmethod__", None)
raw_bytes_request_body = False
if webmethod:
@ -898,6 +920,7 @@ class Generator:
deprecated=getattr(op.webmethod, "deprecated", False)
or "DEPRECATED" in op.func_name,
security=[] if op.public else None,
extraBodyParameters=extra_body_parameters if extra_body_parameters else None,
)
def _get_api_stability_priority(self, api_level: str) -> int:

View file

@ -23,6 +23,8 @@ from fastapi import UploadFile
from fastapi.params import File, Form
from typing import Annotated
from llama_stack.schema_utils import ExtraBodyField
def split_prefix(
s: str, sep: str, prefix: Union[str, Iterable[str]]
@ -89,6 +91,7 @@ class EndpointOperation:
:param query_params: Parameters of the operation signature that are passed in the query string as `key=value` pairs.
:param request_params: The parameter that corresponds to the data transmitted in the request body.
:param multipart_params: Parameters that indicate multipart/form-data request body.
:param extra_body_params: Parameters that arrive via extra_body and are documented but not in SDK.
:param event_type: The Python type of the data that is transmitted out-of-band (e.g. via websockets) while the operation is in progress.
:param response_type: The Python type of the data that is transmitted in the response body.
:param http_method: The HTTP method used to invoke the endpoint such as POST, GET or PUT.
@ -106,6 +109,7 @@ class EndpointOperation:
query_params: List[OperationParameter]
request_params: Optional[OperationParameter]
multipart_params: List[OperationParameter]
extra_body_params: List[tuple[str, type, str | None]]
event_type: Optional[type]
response_type: type
http_method: HTTPMethod
@ -265,6 +269,7 @@ def get_endpoint_operations(
query_params = []
request_params = []
multipart_params = []
extra_body_params = []
for param_name, parameter in signature.parameters.items():
param_type = _get_annotation_type(parameter.annotation, func_ref)
@ -279,6 +284,13 @@ def get_endpoint_operations(
f"parameter '{param_name}' in function '{func_name}' has no type annotation"
)
# Check if this is an extra_body parameter
is_extra_body, extra_body_desc = _is_extra_body_param(param_type)
if is_extra_body:
# Store in a separate list for documentation
extra_body_params.append((param_name, param_type, extra_body_desc))
continue # Skip adding to request_params
is_multipart = _is_multipart_param(param_type)
if prefix in ["get", "delete"]:
@ -351,6 +363,7 @@ def get_endpoint_operations(
query_params=query_params,
request_params=request_params,
multipart_params=multipart_params,
extra_body_params=extra_body_params,
event_type=event_type,
response_type=response_type,
http_method=http_method,
@ -429,3 +442,22 @@ def _is_multipart_param(param_type: type) -> bool:
if isinstance(annotation, (File, Form)):
return True
return False
def _is_extra_body_param(param_type: type) -> tuple[bool, str | None]:
"""
Check if parameter is marked as coming from extra_body.
Returns:
(is_extra_body, description): Tuple of boolean and optional description
"""
origin = get_origin(param_type)
if origin is Annotated:
args = get_args(param_type)
for annotation in args[1:]:
if isinstance(annotation, ExtraBodyField):
return True, annotation.description
# Also check by type name for cases where import matters
if type(annotation).__name__ == 'ExtraBodyField':
return True, getattr(annotation, 'description', None)
return False, None

View file

@ -106,6 +106,15 @@ class Parameter:
example: Optional[Any] = None
@dataclass
class ExtraBodyParameter:
"""Represents a parameter that arrives via extra_body in the request."""
name: str
schema: SchemaOrRef
description: Optional[str] = None
required: Optional[bool] = None
@dataclass
class Operation:
responses: Dict[str, Union[Response, ResponseRef]]
@ -118,6 +127,7 @@ class Operation:
callbacks: Optional[Dict[str, "Callback"]] = None
security: Optional[List["SecurityRequirement"]] = None
deprecated: Optional[bool] = None
extraBodyParameters: Optional[List[ExtraBodyParameter]] = None
@dataclass

View file

@ -52,6 +52,17 @@ class Specification:
if display_name:
tag["x-displayName"] = display_name
# Handle operations to rename extraBodyParameters -> x-llama-stack-extra-body-params
paths = json_doc.get("paths", {})
for path_item in paths.values():
if isinstance(path_item, dict):
for method in ["get", "post", "put", "delete", "patch"]:
operation = path_item.get(method)
if operation and isinstance(operation, dict):
extra_body_params = operation.pop("extraBodyParameters", None)
if extra_body_params:
operation["x-llama-stack-extra-body-params"] = extra_body_params
return json_doc
def get_json_string(self, pretty_print: bool = False) -> str:

View file

@ -1443,8 +1443,8 @@
"tags": [
"Inference"
],
"summary": "List all chat completions.",
"description": "List all chat completions.",
"summary": "List chat completions.",
"description": "List chat completions.",
"parameters": [
{
"name": "after",
@ -1520,8 +1520,8 @@
"tags": [
"Inference"
],
"summary": "Generate an OpenAI-compatible chat completion for the given messages using the specified model.",
"description": "Generate an OpenAI-compatible chat completion for the given messages using the specified model.",
"summary": "Create chat completions.",
"description": "Create chat completions.\nGenerate an OpenAI-compatible chat completion for the given messages using the specified model.",
"parameters": [],
"requestBody": {
"content": {
@ -1565,8 +1565,8 @@
"tags": [
"Inference"
],
"summary": "Describe a chat completion by its ID.",
"description": "Describe a chat completion by its ID.",
"summary": "Get chat completion.",
"description": "Get chat completion.\nDescribe a chat completion by its ID.",
"parameters": [
{
"name": "completion_id",
@ -1610,8 +1610,8 @@
"tags": [
"Inference"
],
"summary": "Generate an OpenAI-compatible completion for the given prompt using the specified model.",
"description": "Generate an OpenAI-compatible completion for the given prompt using the specified model.",
"summary": "Create completion.",
"description": "Create completion.\nGenerate an OpenAI-compatible completion for the given prompt using the specified model.",
"parameters": [],
"requestBody": {
"content": {
@ -1655,8 +1655,8 @@
"tags": [
"Inference"
],
"summary": "Generate OpenAI-compatible embeddings for the given input using the specified model.",
"description": "Generate OpenAI-compatible embeddings for the given input using the specified model.",
"summary": "Create embeddings.",
"description": "Create embeddings.\nGenerate OpenAI-compatible embeddings for the given input using the specified model.",
"parameters": [],
"requestBody": {
"content": {
@ -1700,8 +1700,8 @@
"tags": [
"Files"
],
"summary": "Returns a list of files that belong to the user's organization.",
"description": "Returns a list of files that belong to the user's organization.",
"summary": "List files.",
"description": "List files.\nReturns a list of files that belong to the user's organization.",
"parameters": [
{
"name": "after",
@ -1770,8 +1770,8 @@
"tags": [
"Files"
],
"summary": "Upload a file that can be used across various endpoints.",
"description": "Upload a file that can be used across various endpoints.\nThe file upload should be a multipart form request with:\n- file: The File object (not file name) to be uploaded.\n- purpose: The intended purpose of the uploaded file.\n- expires_after: Optional form values describing expiration for the file.",
"summary": "Upload file.",
"description": "Upload file.\nUpload a file that can be used across various endpoints.\n\nThe file upload should be a multipart form request with:\n- file: The File object (not file name) to be uploaded.\n- purpose: The intended purpose of the uploaded file.\n- expires_after: Optional form values describing expiration for the file.",
"parameters": [],
"requestBody": {
"content": {
@ -1831,8 +1831,8 @@
"tags": [
"Files"
],
"summary": "Returns information about a specific file.",
"description": "Returns information about a specific file.",
"summary": "Retrieve file.",
"description": "Retrieve file.\nReturns information about a specific file.",
"parameters": [
{
"name": "file_id",
@ -1874,8 +1874,8 @@
"tags": [
"Files"
],
"summary": "Delete a file.",
"description": "Delete a file.",
"summary": "Delete file.",
"description": "Delete file.",
"parameters": [
{
"name": "file_id",
@ -1919,8 +1919,8 @@
"tags": [
"Files"
],
"summary": "Returns the contents of the specified file.",
"description": "Returns the contents of the specified file.",
"summary": "Retrieve file content.",
"description": "Retrieve file content.\nReturns the contents of the specified file.",
"parameters": [
{
"name": "file_id",
@ -1999,8 +1999,8 @@
"tags": [
"Safety"
],
"summary": "Classifies if text and/or image inputs are potentially harmful.",
"description": "Classifies if text and/or image inputs are potentially harmful.",
"summary": "Create moderation.",
"description": "Create moderation.\nClassifies if text and/or image inputs are potentially harmful.",
"parameters": [],
"requestBody": {
"content": {
@ -2044,8 +2044,8 @@
"tags": [
"Agents"
],
"summary": "List all OpenAI responses.",
"description": "List all OpenAI responses.",
"summary": "List all responses.",
"description": "List all responses.",
"parameters": [
{
"name": "after",
@ -2119,8 +2119,8 @@
"tags": [
"Agents"
],
"summary": "Create a new OpenAI response.",
"description": "Create a new OpenAI response.",
"summary": "Create a model response.",
"description": "Create a model response.",
"parameters": [],
"requestBody": {
"content": {
@ -2132,7 +2132,27 @@
},
"required": true
},
"deprecated": true
"deprecated": true,
"x-llama-stack-extra-body-params": [
{
"name": "shields",
"schema": {
"type": "array",
"items": {
"oneOf": [
{
"type": "string"
},
{
"$ref": "#/components/schemas/ResponseShieldSpec"
}
]
}
},
"description": "List of shields to apply during response generation. Shields provide safety and content moderation.",
"required": false
}
]
}
},
"/v1/openai/v1/responses/{response_id}": {
@ -2164,8 +2184,8 @@
"tags": [
"Agents"
],
"summary": "Retrieve an OpenAI response by its ID.",
"description": "Retrieve an OpenAI response by its ID.",
"summary": "Get a model response.",
"description": "Get a model response.",
"parameters": [
{
"name": "response_id",
@ -2207,8 +2227,8 @@
"tags": [
"Agents"
],
"summary": "Delete an OpenAI response by its ID.",
"description": "Delete an OpenAI response by its ID.",
"summary": "Delete a response.",
"description": "Delete a response.",
"parameters": [
{
"name": "response_id",
@ -2252,8 +2272,8 @@
"tags": [
"Agents"
],
"summary": "List input items for a given OpenAI response.",
"description": "List input items for a given OpenAI response.",
"summary": "List input items.",
"description": "List input items.",
"parameters": [
{
"name": "response_id",
@ -9521,6 +9541,21 @@
"title": "OpenAIResponseText",
"description": "Text response configuration for OpenAI responses."
},
"ResponseShieldSpec": {
"type": "object",
"properties": {
"type": {
"type": "string",
"description": "The type/identifier of the shield."
}
},
"additionalProperties": false,
"required": [
"type"
],
"title": "ResponseShieldSpec",
"description": "Specification for a shield to apply during response generation."
},
"OpenAIResponseInputTool": {
"oneOf": [
{
@ -13331,12 +13366,13 @@
},
{
"name": "Files",
"description": ""
"description": "This API is used to upload documents that can be used with other Llama Stack APIs.",
"x-displayName": "Files"
},
{
"name": "Inference",
"description": "This API provides the raw interface to the underlying models. Two kinds of models are supported:\n- LLM models: these models generate \"raw\" and \"chat\" (conversational) completions.\n- Embedding models: these models generate embeddings to be used for semantic search.",
"x-displayName": "Llama Stack Inference API for generating completions, chat completions, and embeddings."
"description": "Llama Stack Inference API for generating completions, chat completions, and embeddings.\n\nThis API provides the raw interface to the underlying models. Two kinds of models are supported:\n- LLM models: these models generate \"raw\" and \"chat\" (conversational) completions.\n- Embedding models: these models generate embeddings to be used for semantic search.",
"x-displayName": "Inference"
},
{
"name": "Models",
@ -13348,7 +13384,8 @@
},
{
"name": "Safety",
"description": ""
"description": "OpenAI-compatible Moderations API.",
"x-displayName": "Safety"
},
{
"name": "Telemetry",

View file

@ -1033,8 +1033,8 @@ paths:
$ref: '#/components/responses/DefaultError'
tags:
- Inference
summary: List all chat completions.
description: List all chat completions.
summary: List chat completions.
description: List chat completions.
parameters:
- name: after
in: query
@ -1087,10 +1087,10 @@ paths:
$ref: '#/components/responses/DefaultError'
tags:
- Inference
summary: >-
Generate an OpenAI-compatible chat completion for the given messages using
the specified model.
summary: Create chat completions.
description: >-
Create chat completions.
Generate an OpenAI-compatible chat completion for the given messages using
the specified model.
parameters: []
@ -1122,8 +1122,11 @@ paths:
$ref: '#/components/responses/DefaultError'
tags:
- Inference
summary: Describe a chat completion by its ID.
description: Describe a chat completion by its ID.
summary: Get chat completion.
description: >-
Get chat completion.
Describe a chat completion by its ID.
parameters:
- name: completion_id
in: path
@ -1153,10 +1156,10 @@ paths:
$ref: '#/components/responses/DefaultError'
tags:
- Inference
summary: >-
Generate an OpenAI-compatible completion for the given prompt using the specified
model.
summary: Create completion.
description: >-
Create completion.
Generate an OpenAI-compatible completion for the given prompt using the specified
model.
parameters: []
@ -1189,10 +1192,10 @@ paths:
$ref: '#/components/responses/DefaultError'
tags:
- Inference
summary: >-
Generate OpenAI-compatible embeddings for the given input using the specified
model.
summary: Create embeddings.
description: >-
Create embeddings.
Generate OpenAI-compatible embeddings for the given input using the specified
model.
parameters: []
@ -1225,9 +1228,10 @@ paths:
$ref: '#/components/responses/DefaultError'
tags:
- Files
summary: >-
Returns a list of files that belong to the user's organization.
summary: List files.
description: >-
List files.
Returns a list of files that belong to the user's organization.
parameters:
- name: after
@ -1285,11 +1289,13 @@ paths:
$ref: '#/components/responses/DefaultError'
tags:
- Files
summary: >-
Upload a file that can be used across various endpoints.
summary: Upload file.
description: >-
Upload file.
Upload a file that can be used across various endpoints.
The file upload should be a multipart form request with:
- file: The File object (not file name) to be uploaded.
@ -1338,9 +1344,10 @@ paths:
$ref: '#/components/responses/DefaultError'
tags:
- Files
summary: >-
Returns information about a specific file.
summary: Retrieve file.
description: >-
Retrieve file.
Returns information about a specific file.
parameters:
- name: file_id
@ -1372,8 +1379,8 @@ paths:
$ref: '#/components/responses/DefaultError'
tags:
- Files
summary: Delete a file.
description: Delete a file.
summary: Delete file.
description: Delete file.
parameters:
- name: file_id
in: path
@ -1405,9 +1412,10 @@ paths:
$ref: '#/components/responses/DefaultError'
tags:
- Files
summary: >-
Returns the contents of the specified file.
summary: Retrieve file content.
description: >-
Retrieve file content.
Returns the contents of the specified file.
parameters:
- name: file_id
@ -1464,9 +1472,10 @@ paths:
$ref: '#/components/responses/DefaultError'
tags:
- Safety
summary: >-
Classifies if text and/or image inputs are potentially harmful.
summary: Create moderation.
description: >-
Create moderation.
Classifies if text and/or image inputs are potentially harmful.
parameters: []
requestBody:
@ -1497,8 +1506,8 @@ paths:
$ref: '#/components/responses/DefaultError'
tags:
- Agents
summary: List all OpenAI responses.
description: List all OpenAI responses.
summary: List all responses.
description: List all responses.
parameters:
- name: after
in: query
@ -1549,8 +1558,8 @@ paths:
$ref: '#/components/responses/DefaultError'
tags:
- Agents
summary: Create a new OpenAI response.
description: Create a new OpenAI response.
summary: Create a model response.
description: Create a model response.
parameters: []
requestBody:
content:
@ -1559,6 +1568,18 @@ paths:
$ref: '#/components/schemas/CreateOpenaiResponseRequest'
required: true
deprecated: true
x-llama-stack-extra-body-params:
- name: shields
schema:
type: array
items:
oneOf:
- type: string
- $ref: '#/components/schemas/ResponseShieldSpec'
description: >-
List of shields to apply during response generation. Shields provide safety
and content moderation.
required: false
/v1/openai/v1/responses/{response_id}:
get:
responses:
@ -1580,8 +1601,8 @@ paths:
$ref: '#/components/responses/DefaultError'
tags:
- Agents
summary: Retrieve an OpenAI response by its ID.
description: Retrieve an OpenAI response by its ID.
summary: Get a model response.
description: Get a model response.
parameters:
- name: response_id
in: path
@ -1611,8 +1632,8 @@ paths:
$ref: '#/components/responses/DefaultError'
tags:
- Agents
summary: Delete an OpenAI response by its ID.
description: Delete an OpenAI response by its ID.
summary: Delete a response.
description: Delete a response.
parameters:
- name: response_id
in: path
@ -1642,10 +1663,8 @@ paths:
$ref: '#/components/responses/DefaultError'
tags:
- Agents
summary: >-
List input items for a given OpenAI response.
description: >-
List input items for a given OpenAI response.
summary: List input items.
description: List input items.
parameters:
- name: response_id
in: path
@ -7076,6 +7095,18 @@ components:
title: OpenAIResponseText
description: >-
Text response configuration for OpenAI responses.
ResponseShieldSpec:
type: object
properties:
type:
type: string
description: The type/identifier of the shield.
additionalProperties: false
required:
- type
title: ResponseShieldSpec
description: >-
Specification for a shield to apply during response generation.
OpenAIResponseInputTool:
oneOf:
- $ref: '#/components/schemas/OpenAIResponseInputToolWebSearch'
@ -9987,9 +10018,16 @@ tags:
x-displayName: >-
Llama Stack Evaluation API for running evaluations on model and agent candidates.
- name: Files
description: ''
description: >-
This API is used to upload documents that can be used with other Llama Stack
APIs.
x-displayName: Files
- name: Inference
description: >-
Llama Stack Inference API for generating completions, chat completions, and
embeddings.
This API provides the raw interface to the underlying models. Two kinds of models
are supported:
@ -9997,15 +10035,14 @@ tags:
- Embedding models: these models generate embeddings to be used for semantic
search.
x-displayName: >-
Llama Stack Inference API for generating completions, chat completions, and
embeddings.
x-displayName: Inference
- name: Models
description: ''
- name: PostTraining (Coming Soon)
description: ''
- name: Safety
description: ''
description: OpenAI-compatible Moderations API.
x-displayName: Safety
- name: Telemetry
description: ''
- name: VectorIO

Binary file not shown.

Before

Width:  |  Height:  |  Size: 196 KiB

After

Width:  |  Height:  |  Size: 604 KiB

Before After
Before After

File diff suppressed because it is too large Load diff

File diff suppressed because it is too large Load diff

File diff suppressed because it is too large Load diff

File diff suppressed because it is too large Load diff

View file

@ -28,7 +28,7 @@ from llama_stack.apis.inference import (
from llama_stack.apis.safety import SafetyViolation
from llama_stack.apis.tools import ToolDef
from llama_stack.apis.version import LLAMA_STACK_API_V1, LLAMA_STACK_API_V1ALPHA
from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
from llama_stack.schema_utils import ExtraBodyField, json_schema_type, register_schema, webmethod
from .openai_responses import (
ListOpenAIResponseInputItem,
@ -42,6 +42,20 @@ from .openai_responses import (
)
@json_schema_type
class ResponseShieldSpec(BaseModel):
"""Specification for a shield to apply during response generation.
:param type: The type/identifier of the shield.
"""
type: str
# TODO: more fields to be added for shield configuration
ResponseShield = str | ResponseShieldSpec
class Attachment(BaseModel):
"""An attachment to an agent turn.
@ -783,7 +797,7 @@ class Agents(Protocol):
self,
response_id: str,
) -> OpenAIResponseObject:
"""Retrieve an OpenAI response by its ID.
"""Get a model response.
:param response_id: The ID of the OpenAI response to retrieve.
:returns: An OpenAIResponseObject.
@ -805,13 +819,20 @@ class Agents(Protocol):
tools: list[OpenAIResponseInputTool] | None = None,
include: list[str] | None = None,
max_infer_iters: int | None = 10, # this is an extension to the OpenAI API
shields: Annotated[
list[ResponseShield] | None,
ExtraBodyField(
"List of shields to apply during response generation. Shields provide safety and content moderation."
),
] = None,
) -> OpenAIResponseObject | AsyncIterator[OpenAIResponseObjectStream]:
"""Create a new OpenAI response.
"""Create a model response.
:param input: Input message(s) to create the response.
:param model: The underlying LLM used for completions.
:param previous_response_id: (Optional) if specified, the new response will be a continuation of the previous response. This can be used to easily fork-off new responses from existing responses.
:param include: (Optional) Additional fields to include in the response.
:param shields: (Optional) List of shields to apply during response generation. Can be shield IDs (strings) or shield specifications.
:returns: An OpenAIResponseObject.
"""
...
@ -825,7 +846,7 @@ class Agents(Protocol):
model: str | None = None,
order: Order | None = Order.desc,
) -> ListOpenAIResponseObject:
"""List all OpenAI responses.
"""List all responses.
:param after: The ID of the last response to return.
:param limit: The number of responses to return.
@ -848,7 +869,7 @@ class Agents(Protocol):
limit: int | None = 20,
order: Order | None = Order.desc,
) -> ListOpenAIResponseInputItem:
"""List input items for a given OpenAI response.
"""List input items.
:param response_id: The ID of the response to retrieve input items for.
:param after: An item ID to list items after, used for pagination.
@ -863,7 +884,7 @@ class Agents(Protocol):
@webmethod(route="/openai/v1/responses/{response_id}", method="DELETE", level=LLAMA_STACK_API_V1, deprecated=True)
@webmethod(route="/responses/{response_id}", method="DELETE", level=LLAMA_STACK_API_V1)
async def delete_openai_response(self, response_id: str) -> OpenAIDeleteResponseObject:
"""Delete an OpenAI response by its ID.
"""Delete a response.
:param response_id: The ID of the OpenAI response to delete.
:returns: An OpenAIDeleteResponseObject

View file

@ -888,6 +888,10 @@ class OpenAIResponseObjectWithInput(OpenAIResponseObject):
input: list[OpenAIResponseInput]
def to_response_object(self) -> OpenAIResponseObject:
"""Convert to OpenAIResponseObject by excluding input field."""
return OpenAIResponseObject(**{k: v for k, v in self.model_dump().items() if k != "input"})
@json_schema_type
class ListOpenAIResponseObject(BaseModel):

View file

@ -0,0 +1,31 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .conversations import (
Conversation,
ConversationCreateRequest,
ConversationDeletedResource,
ConversationItem,
ConversationItemCreateRequest,
ConversationItemDeletedResource,
ConversationItemList,
Conversations,
ConversationUpdateRequest,
Metadata,
)
__all__ = [
"Conversation",
"ConversationCreateRequest",
"ConversationDeletedResource",
"ConversationItem",
"ConversationItemCreateRequest",
"ConversationItemDeletedResource",
"ConversationItemList",
"Conversations",
"ConversationUpdateRequest",
"Metadata",
]

View file

@ -0,0 +1,260 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Annotated, Literal, Protocol, runtime_checkable
from openai import NOT_GIVEN
from openai._types import NotGiven
from openai.types.responses.response_includable import ResponseIncludable
from pydantic import BaseModel, Field
from llama_stack.apis.agents.openai_responses import (
OpenAIResponseMessage,
OpenAIResponseOutputMessageFileSearchToolCall,
OpenAIResponseOutputMessageFunctionToolCall,
OpenAIResponseOutputMessageMCPCall,
OpenAIResponseOutputMessageMCPListTools,
OpenAIResponseOutputMessageWebSearchToolCall,
)
from llama_stack.apis.version import LLAMA_STACK_API_V1
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
Metadata = dict[str, str]
@json_schema_type
class Conversation(BaseModel):
"""OpenAI-compatible conversation object."""
id: str = Field(..., description="The unique ID of the conversation.")
object: Literal["conversation"] = Field(
default="conversation", description="The object type, which is always conversation."
)
created_at: int = Field(
..., description="The time at which the conversation was created, measured in seconds since the Unix epoch."
)
metadata: Metadata | None = Field(
default=None,
description="Set of 16 key-value pairs that can be attached to an object. This can be useful for storing additional information about the object in a structured format, and querying for objects via API or the dashboard.",
)
items: list[dict] | None = Field(
default=None,
description="Initial items to include in the conversation context. You may add up to 20 items at a time.",
)
@json_schema_type
class ConversationMessage(BaseModel):
"""OpenAI-compatible message item for conversations."""
id: str = Field(..., description="unique identifier for this message")
content: list[dict] = Field(..., description="message content")
role: str = Field(..., description="message role")
status: str = Field(..., description="message status")
type: Literal["message"] = "message"
object: Literal["message"] = "message"
ConversationItem = Annotated[
OpenAIResponseMessage
| OpenAIResponseOutputMessageFunctionToolCall
| OpenAIResponseOutputMessageFileSearchToolCall
| OpenAIResponseOutputMessageWebSearchToolCall
| OpenAIResponseOutputMessageMCPCall
| OpenAIResponseOutputMessageMCPListTools,
Field(discriminator="type"),
]
register_schema(ConversationItem, name="ConversationItem")
# Using OpenAI types directly caused issues but some notes for reference:
# Note that ConversationItem is a Annotated Union of the types below:
# from openai.types.responses import *
# from openai.types.responses.response_item import *
# from openai.types.conversations import ConversationItem
# f = [
# ResponseFunctionToolCallItem,
# ResponseFunctionToolCallOutputItem,
# ResponseFileSearchToolCall,
# ResponseFunctionWebSearch,
# ImageGenerationCall,
# ResponseComputerToolCall,
# ResponseComputerToolCallOutputItem,
# ResponseReasoningItem,
# ResponseCodeInterpreterToolCall,
# LocalShellCall,
# LocalShellCallOutput,
# McpListTools,
# McpApprovalRequest,
# McpApprovalResponse,
# McpCall,
# ResponseCustomToolCall,
# ResponseCustomToolCallOutput
# ]
@json_schema_type
class ConversationCreateRequest(BaseModel):
"""Request body for creating a conversation."""
items: list[ConversationItem] | None = Field(
default=[],
description="Initial items to include in the conversation context. You may add up to 20 items at a time.",
max_length=20,
)
metadata: Metadata | None = Field(
default={},
description="Set of 16 key-value pairs that can be attached to an object. Useful for storing additional information",
max_length=16,
)
@json_schema_type
class ConversationUpdateRequest(BaseModel):
"""Request body for updating a conversation."""
metadata: Metadata = Field(
...,
description="Set of 16 key-value pairs that can be attached to an object. This can be useful for storing additional information about the object in a structured format, and querying for objects via API or the dashboard. Keys are strings with a maximum length of 64 characters. Values are strings with a maximum length of 512 characters.",
)
@json_schema_type
class ConversationDeletedResource(BaseModel):
"""Response for deleted conversation."""
id: str = Field(..., description="The deleted conversation identifier")
object: str = Field(default="conversation.deleted", description="Object type")
deleted: bool = Field(default=True, description="Whether the object was deleted")
@json_schema_type
class ConversationItemCreateRequest(BaseModel):
"""Request body for creating conversation items."""
items: list[ConversationItem] = Field(
...,
description="Items to include in the conversation context. You may add up to 20 items at a time.",
max_length=20,
)
@json_schema_type
class ConversationItemList(BaseModel):
"""List of conversation items with pagination."""
object: str = Field(default="list", description="Object type")
data: list[ConversationItem] = Field(..., description="List of conversation items")
first_id: str | None = Field(default=None, description="The ID of the first item in the list")
last_id: str | None = Field(default=None, description="The ID of the last item in the list")
has_more: bool = Field(default=False, description="Whether there are more items available")
@json_schema_type
class ConversationItemDeletedResource(BaseModel):
"""Response for deleted conversation item."""
id: str = Field(..., description="The deleted item identifier")
object: str = Field(default="conversation.item.deleted", description="Object type")
deleted: bool = Field(default=True, description="Whether the object was deleted")
@runtime_checkable
@trace_protocol
class Conversations(Protocol):
"""Protocol for conversation management operations."""
@webmethod(route="/conversations", method="POST", level=LLAMA_STACK_API_V1)
async def create_conversation(
self, items: list[ConversationItem] | None = None, metadata: Metadata | None = None
) -> Conversation:
"""Create a conversation.
:param items: Initial items to include in the conversation context.
:param metadata: Set of key-value pairs that can be attached to an object.
:returns: The created conversation object.
"""
...
@webmethod(route="/conversations/{conversation_id}", method="GET", level=LLAMA_STACK_API_V1)
async def get_conversation(self, conversation_id: str) -> Conversation:
"""Get a conversation with the given ID.
:param conversation_id: The conversation identifier.
:returns: The conversation object.
"""
...
@webmethod(route="/conversations/{conversation_id}", method="POST", level=LLAMA_STACK_API_V1)
async def update_conversation(self, conversation_id: str, metadata: Metadata) -> Conversation:
"""Update a conversation's metadata with the given ID.
:param conversation_id: The conversation identifier.
:param metadata: Set of key-value pairs that can be attached to an object.
:returns: The updated conversation object.
"""
...
@webmethod(route="/conversations/{conversation_id}", method="DELETE", level=LLAMA_STACK_API_V1)
async def openai_delete_conversation(self, conversation_id: str) -> ConversationDeletedResource:
"""Delete a conversation with the given ID.
:param conversation_id: The conversation identifier.
:returns: The deleted conversation resource.
"""
...
@webmethod(route="/conversations/{conversation_id}/items", method="POST", level=LLAMA_STACK_API_V1)
async def add_items(self, conversation_id: str, items: list[ConversationItem]) -> ConversationItemList:
"""Create items in the conversation.
:param conversation_id: The conversation identifier.
:param items: Items to include in the conversation context.
:returns: List of created items.
"""
...
@webmethod(route="/conversations/{conversation_id}/items/{item_id}", method="GET", level=LLAMA_STACK_API_V1)
async def retrieve(self, conversation_id: str, item_id: str) -> ConversationItem:
"""Retrieve a conversation item.
:param conversation_id: The conversation identifier.
:param item_id: The item identifier.
:returns: The conversation item.
"""
...
@webmethod(route="/conversations/{conversation_id}/items", method="GET", level=LLAMA_STACK_API_V1)
async def list(
self,
conversation_id: str,
after: str | NotGiven = NOT_GIVEN,
include: list[ResponseIncludable] | NotGiven = NOT_GIVEN,
limit: int | NotGiven = NOT_GIVEN,
order: Literal["asc", "desc"] | NotGiven = NOT_GIVEN,
) -> ConversationItemList:
"""List items in the conversation.
:param conversation_id: The conversation identifier.
:param after: An item ID to list items after, used in pagination.
:param include: Specify additional output data to include in the response.
:param limit: A limit on the number of objects to be returned (1-100, default 20).
:param order: The order to return items in (asc or desc, default desc).
:returns: List of conversation items.
"""
...
@webmethod(route="/conversations/{conversation_id}/items/{item_id}", method="DELETE", level=LLAMA_STACK_API_V1)
async def openai_delete_conversation_item(
self, conversation_id: str, item_id: str
) -> ConversationItemDeletedResource:
"""Delete a conversation item.
:param conversation_id: The conversation identifier.
:param item_id: The item identifier.
:returns: The deleted item resource.
"""
...

View file

@ -129,6 +129,7 @@ class Api(Enum, metaclass=DynamicApiMeta):
tool_groups = "tool_groups"
files = "files"
prompts = "prompts"
conversations = "conversations"
# built-in API
inspect = "inspect"

View file

@ -104,6 +104,11 @@ class OpenAIFileDeleteResponse(BaseModel):
@runtime_checkable
@trace_protocol
class Files(Protocol):
"""Files
This API is used to upload documents that can be used with other Llama Stack APIs.
"""
# OpenAI Files API Endpoints
@webmethod(route="/openai/v1/files", method="POST", level=LLAMA_STACK_API_V1, deprecated=True)
@webmethod(route="/files", method="POST", level=LLAMA_STACK_API_V1)
@ -113,7 +118,8 @@ class Files(Protocol):
purpose: Annotated[OpenAIFilePurpose, Form()],
expires_after: Annotated[ExpiresAfter | None, Form()] = None,
) -> OpenAIFileObject:
"""
"""Upload file.
Upload a file that can be used across various endpoints.
The file upload should be a multipart form request with:
@ -137,7 +143,8 @@ class Files(Protocol):
order: Order | None = Order.desc,
purpose: OpenAIFilePurpose | None = None,
) -> ListOpenAIFileResponse:
"""
"""List files.
Returns a list of files that belong to the user's organization.
:param after: A cursor for use in pagination. `after` is an object ID that defines your place in the list. For instance, if you make a list request and receive 100 objects, ending with obj_foo, your subsequent call can include after=obj_foo in order to fetch the next page of the list.
@ -154,7 +161,8 @@ class Files(Protocol):
self,
file_id: str,
) -> OpenAIFileObject:
"""
"""Retrieve file.
Returns information about a specific file.
:param file_id: The ID of the file to use for this request.
@ -168,8 +176,7 @@ class Files(Protocol):
self,
file_id: str,
) -> OpenAIFileDeleteResponse:
"""
Delete a file.
"""Delete file.
:param file_id: The ID of the file to use for this request.
:returns: An OpenAIFileDeleteResponse indicating successful deletion.
@ -182,7 +189,8 @@ class Files(Protocol):
self,
file_id: str,
) -> Response:
"""
"""Retrieve file content.
Returns the contents of the specified file.
:param file_id: The ID of the file to use for this request.

View file

@ -982,45 +982,6 @@ class InferenceProvider(Protocol):
model_store: ModelStore | None = None
async def chat_completion(
self,
model_id: str,
messages: list[Message],
sampling_params: SamplingParams | None = None,
tools: list[ToolDefinition] | None = None,
tool_choice: ToolChoice | None = ToolChoice.auto,
tool_prompt_format: ToolPromptFormat | None = None,
response_format: ResponseFormat | None = None,
stream: bool | None = False,
logprobs: LogProbConfig | None = None,
tool_config: ToolConfig | None = None,
) -> ChatCompletionResponse | AsyncIterator[ChatCompletionResponseStreamChunk]:
"""Generate a chat completion for the given messages using the specified model.
:param model_id: The identifier of the model to use. The model must be registered with Llama Stack and available via the /models endpoint.
:param messages: List of messages in the conversation.
:param sampling_params: Parameters to control the sampling strategy.
:param tools: (Optional) List of tool definitions available to the model.
:param tool_choice: (Optional) Whether tool use is required or automatic. Defaults to ToolChoice.auto.
.. deprecated::
Use tool_config instead.
:param tool_prompt_format: (Optional) Instructs the model how to format tool calls. By default, Llama Stack will attempt to use a format that is best adapted to the model.
- `ToolPromptFormat.json`: The tool calls are formatted as a JSON object.
- `ToolPromptFormat.function_tag`: The tool calls are enclosed in a <function=function_name> tag.
- `ToolPromptFormat.python_list`: The tool calls are output as Python syntax -- a list of function calls.
.. deprecated::
Use tool_config instead.
:param response_format: (Optional) Grammar specification for guided (structured) decoding. There are two options:
- `ResponseFormat.json_schema`: The grammar is a JSON schema. Most providers support this format.
- `ResponseFormat.grammar`: The grammar is a BNF grammar. This format is more flexible, but not all providers support it.
:param stream: (Optional) If True, generate an SSE event stream of the response. Defaults to False.
:param logprobs: (Optional) If specified, log probabilities for each token position will be returned.
:param tool_config: (Optional) Configuration for tool use.
:returns: If stream=False, returns a ChatCompletionResponse with the full completion.
If stream=True, returns an SSE event stream of ChatCompletionResponseStreamChunk.
"""
...
@webmethod(route="/inference/rerank", method="POST", level=LLAMA_STACK_API_V1ALPHA)
async def rerank(
self,
@ -1081,7 +1042,9 @@ class InferenceProvider(Protocol):
# for fill-in-the-middle type completion
suffix: str | None = None,
) -> OpenAICompletion:
"""Generate an OpenAI-compatible completion for the given prompt using the specified model.
"""Create completion.
Generate an OpenAI-compatible completion for the given prompt using the specified model.
:param model: The identifier of the model to use. The model must be registered with Llama Stack and available via the /models endpoint.
:param prompt: The prompt to generate a completion for.
@ -1138,7 +1101,9 @@ class InferenceProvider(Protocol):
top_p: float | None = None,
user: str | None = None,
) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
"""Generate an OpenAI-compatible chat completion for the given messages using the specified model.
"""Create chat completions.
Generate an OpenAI-compatible chat completion for the given messages using the specified model.
:param model: The identifier of the model to use. The model must be registered with Llama Stack and available via the /models endpoint.
:param messages: List of messages in the conversation.
@ -1182,7 +1147,9 @@ class InferenceProvider(Protocol):
dimensions: int | None = None,
user: str | None = None,
) -> OpenAIEmbeddingsResponse:
"""Generate OpenAI-compatible embeddings for the given input using the specified model.
"""Create embeddings.
Generate OpenAI-compatible embeddings for the given input using the specified model.
:param model: The identifier of the model to use. The model must be an embedding model registered with Llama Stack and available via the /models endpoint.
:param input: Input text to embed, encoded as a string or array of strings. To embed multiple inputs in a single request, pass an array of strings.
@ -1195,7 +1162,9 @@ class InferenceProvider(Protocol):
class Inference(InferenceProvider):
"""Llama Stack Inference API for generating completions, chat completions, and embeddings.
"""Inference
Llama Stack Inference API for generating completions, chat completions, and embeddings.
This API provides the raw interface to the underlying models. Two kinds of models are supported:
- LLM models: these models generate "raw" and "chat" (conversational) completions.
@ -1216,7 +1185,7 @@ class Inference(InferenceProvider):
model: str | None = None,
order: Order | None = Order.desc,
) -> ListOpenAIChatCompletionResponse:
"""List all chat completions.
"""List chat completions.
:param after: The ID of the last chat completion to return.
:param limit: The maximum number of chat completions to return.
@ -1237,10 +1206,11 @@ class Inference(InferenceProvider):
method="GET",
level=LLAMA_STACK_API_V1,
)
async def get_chat_completion(
self, completion_id: str
) -> OpenAICompletionWithInputMessages:
"""Describe a chat completion by its ID.
@webmethod(route="/chat/completions/{completion_id}", method="GET", level=LLAMA_STACK_API_V1)
async def get_chat_completion(self, completion_id: str) -> OpenAICompletionWithInputMessages:
"""Get chat completion.
Describe a chat completion by its ID.
:param completion_id: ID of the chat completion.
:returns: A OpenAICompletionWithInputMessages.

View file

@ -58,9 +58,16 @@ class ListRoutesResponse(BaseModel):
@runtime_checkable
class Inspect(Protocol):
"""Inspect
APIs for inspecting the Llama Stack service, including health status, available API routes with methods and implementing providers.
"""
@webmethod(route="/inspect/routes", method="GET", level=LLAMA_STACK_API_V1)
async def list_routes(self) -> ListRoutesResponse:
"""List all available API routes with their methods and implementing providers.
"""List routes.
List all available API routes with their methods and implementing providers.
:returns: Response containing information about all available routes.
"""
@ -68,7 +75,9 @@ class Inspect(Protocol):
@webmethod(route="/health", method="GET", level=LLAMA_STACK_API_V1)
async def health(self) -> HealthInfo:
"""Get the current health status of the service.
"""Get health status.
Get the current health status of the service.
:returns: Health information indicating if the service is operational.
"""
@ -76,7 +85,9 @@ class Inspect(Protocol):
@webmethod(route="/version", method="GET", level=LLAMA_STACK_API_V1)
async def version(self) -> VersionInfo:
"""Get the version of the service.
"""Get version.
Get the version of the service.
:returns: Version information containing the service version number.
"""

View file

@ -124,7 +124,9 @@ class Models(Protocol):
self,
model_id: str,
) -> Model:
"""Get a model by its identifier.
"""Get model.
Get a model by its identifier.
:param model_id: The identifier of the model to get.
:returns: A Model.
@ -140,7 +142,9 @@ class Models(Protocol):
metadata: dict[str, Any] | None = None,
model_type: ModelType | None = None,
) -> Model:
"""Register a model.
"""Register model.
Register a model.
:param model_id: The identifier of the model to register.
:param provider_model_id: The identifier of the model in the provider.
@ -156,7 +160,9 @@ class Models(Protocol):
self,
model_id: str,
) -> None:
"""Unregister a model.
"""Unregister model.
Unregister a model.
:param model_id: The identifier of the model to unregister.
"""

View file

@ -94,7 +94,9 @@ class ListPromptsResponse(BaseModel):
@runtime_checkable
@trace_protocol
class Prompts(Protocol):
"""Protocol for prompt management operations."""
"""Prompts
Protocol for prompt management operations."""
@webmethod(route="/prompts", method="GET", level=LLAMA_STACK_API_V1)
async def list_prompts(self) -> ListPromptsResponse:
@ -109,7 +111,9 @@ class Prompts(Protocol):
self,
prompt_id: str,
) -> ListPromptsResponse:
"""List all versions of a specific prompt.
"""List prompt versions.
List all versions of a specific prompt.
:param prompt_id: The identifier of the prompt to list versions for.
:returns: A ListPromptsResponse containing all versions of the prompt.
@ -122,7 +126,9 @@ class Prompts(Protocol):
prompt_id: str,
version: int | None = None,
) -> Prompt:
"""Get a prompt by its identifier and optional version.
"""Get prompt.
Get a prompt by its identifier and optional version.
:param prompt_id: The identifier of the prompt to get.
:param version: The version of the prompt to get (defaults to latest).
@ -136,7 +142,9 @@ class Prompts(Protocol):
prompt: str,
variables: list[str] | None = None,
) -> Prompt:
"""Create a new prompt.
"""Create prompt.
Create a new prompt.
:param prompt: The prompt text content with variable placeholders.
:param variables: List of variable names that can be used in the prompt template.
@ -153,7 +161,9 @@ class Prompts(Protocol):
variables: list[str] | None = None,
set_as_default: bool = True,
) -> Prompt:
"""Update an existing prompt (increments version).
"""Update prompt.
Update an existing prompt (increments version).
:param prompt_id: The identifier of the prompt to update.
:param prompt: The updated prompt text content.
@ -169,7 +179,9 @@ class Prompts(Protocol):
self,
prompt_id: str,
) -> None:
"""Delete a prompt.
"""Delete prompt.
Delete a prompt.
:param prompt_id: The identifier of the prompt to delete.
"""
@ -181,7 +193,9 @@ class Prompts(Protocol):
prompt_id: str,
version: int,
) -> Prompt:
"""Set which version of a prompt should be the default in get_prompt (latest).
"""Set prompt version.
Set which version of a prompt should be the default in get_prompt (latest).
:param prompt_id: The identifier of the prompt.
:param version: The version to set as default.

View file

@ -42,13 +42,16 @@ class ListProvidersResponse(BaseModel):
@runtime_checkable
class Providers(Protocol):
"""
"""Providers
Providers API for inspecting, listing, and modifying providers and their configurations.
"""
@webmethod(route="/providers", method="GET", level=LLAMA_STACK_API_V1)
async def list_providers(self) -> ListProvidersResponse:
"""List all available providers.
"""List providers.
List all available providers.
:returns: A ListProvidersResponse containing information about all providers.
"""
@ -56,7 +59,9 @@ class Providers(Protocol):
@webmethod(route="/providers/{provider_id}", method="GET", level=LLAMA_STACK_API_V1)
async def inspect_provider(self, provider_id: str) -> ProviderInfo:
"""Get detailed information about a specific provider.
"""Get provider.
Get detailed information about a specific provider.
:param provider_id: The ID of the provider to inspect.
:returns: A ProviderInfo object containing the provider's details.

View file

@ -96,6 +96,11 @@ class ShieldStore(Protocol):
@runtime_checkable
@trace_protocol
class Safety(Protocol):
"""Safety
OpenAI-compatible Moderations API.
"""
shield_store: ShieldStore
@webmethod(route="/safety/run-shield", method="POST", level=LLAMA_STACK_API_V1)
@ -105,7 +110,9 @@ class Safety(Protocol):
messages: list[Message],
params: dict[str, Any],
) -> RunShieldResponse:
"""Run a shield.
"""Run shield.
Run a shield.
:param shield_id: The identifier of the shield to run.
:param messages: The messages to run the shield on.
@ -117,7 +124,9 @@ class Safety(Protocol):
@webmethod(route="/openai/v1/moderations", method="POST", level=LLAMA_STACK_API_V1, deprecated=True)
@webmethod(route="/moderations", method="POST", level=LLAMA_STACK_API_V1)
async def run_moderation(self, input: str | list[str], model: str) -> ModerationObject:
"""Classifies if text and/or image inputs are potentially harmful.
"""Create moderation.
Classifies if text and/or image inputs are potentially harmful.
:param input: Input (or inputs) to classify.
Can be a single string, an array of strings, or an array of multi-modal input objects similar to other models.
:param model: The content moderation model you would like to use.

View file

@ -6,11 +6,18 @@
import argparse
import os
import ssl
import subprocess
from pathlib import Path
import uvicorn
import yaml
from llama_stack.cli.stack.utils import ImageType
from llama_stack.cli.subcommand import Subcommand
from llama_stack.core.datatypes import LoggingConfig, StackRunConfig
from llama_stack.core.stack import cast_image_name_to_string, replace_env_vars, validate_env_pair
from llama_stack.core.utils.config_resolution import Mode, resolve_config_or_distro
from llama_stack.log import get_logger
REPO_ROOT = Path(__file__).parent.parent.parent.parent
@ -146,23 +153,7 @@ class StackRun(Subcommand):
# using the current environment packages.
if not image_type and not image_name:
logger.info("No image type or image name provided. Assuming environment packages.")
from llama_stack.core.server.server import main as server_main
# Build the server args from the current args passed to the CLI
server_args = argparse.Namespace()
for arg in vars(args):
# If this is a function, avoid passing it
# "args" contains:
# func=<bound method StackRun._run_stack_run_cmd of <llama_stack.cli.stack.run.StackRun object at 0x10484b010>>
if callable(getattr(args, arg)):
continue
if arg == "config":
server_args.config = str(config_file)
else:
setattr(server_args, arg, getattr(args, arg))
# Run the server
server_main(server_args)
self._uvicorn_run(config_file, args)
else:
run_args = formulate_run_args(image_type, image_name)
@ -184,6 +175,76 @@ class StackRun(Subcommand):
run_command(run_args)
def _uvicorn_run(self, config_file: Path | None, args: argparse.Namespace) -> None:
if not config_file:
self.parser.error("Config file is required")
# Set environment variables if provided
if args.env:
for env_pair in args.env:
try:
key, value = validate_env_pair(env_pair)
logger.info(f"Setting environment variable {key} => {value}")
os.environ[key] = value
except ValueError as e:
logger.error(f"Error: {str(e)}")
self.parser.error(f"Invalid environment variable format: {env_pair}")
config_file = resolve_config_or_distro(str(config_file), Mode.RUN)
with open(config_file) as fp:
config_contents = yaml.safe_load(fp)
if isinstance(config_contents, dict) and (cfg := config_contents.get("logging_config")):
logger_config = LoggingConfig(**cfg)
else:
logger_config = None
config = StackRunConfig(**cast_image_name_to_string(replace_env_vars(config_contents)))
port = args.port or config.server.port
host = config.server.host or ["::", "0.0.0.0"]
# Set the config file in environment so create_app can find it
os.environ["LLAMA_STACK_CONFIG"] = str(config_file)
uvicorn_config = {
"factory": True,
"host": host,
"port": port,
"lifespan": "on",
"log_level": logger.getEffectiveLevel(),
"log_config": logger_config,
}
keyfile = config.server.tls_keyfile
certfile = config.server.tls_certfile
if keyfile and certfile:
uvicorn_config["ssl_keyfile"] = config.server.tls_keyfile
uvicorn_config["ssl_certfile"] = config.server.tls_certfile
if config.server.tls_cafile:
uvicorn_config["ssl_ca_certs"] = config.server.tls_cafile
uvicorn_config["ssl_cert_reqs"] = ssl.CERT_REQUIRED
logger.info(
f"HTTPS enabled with certificates:\n Key: {keyfile}\n Cert: {certfile}\n CA: {config.server.tls_cafile}"
)
else:
logger.info(f"HTTPS enabled with certificates:\n Key: {keyfile}\n Cert: {certfile}")
logger.info(f"Listening on {host}:{port}")
# We need to catch KeyboardInterrupt because uvicorn's signal handling
# re-raises SIGINT signals using signal.raise_signal(), which Python
# converts to KeyboardInterrupt. Without this catch, we'd get a confusing
# stack trace when using Ctrl+C or kill -2 (SIGINT).
# SIGTERM (kill -15) works fine without this because Python doesn't
# have a default handler for it.
#
# Another approach would be to ignore SIGINT entirely - let uvicorn handle it through its own
# signal handling but this is quite intrusive and not worth the effort.
try:
uvicorn.run("llama_stack.core.server.server:create_app", **uvicorn_config)
except (KeyboardInterrupt, SystemExit):
logger.info("Received interrupt signal, shutting down gracefully...")
def _start_ui_development_server(self, stack_server_port: int):
logger.info("Attempting to start UI development server...")
# Check if npm is available

View file

@ -324,14 +324,14 @@ fi
RUN pip uninstall -y uv
EOF
# If a run config is provided, we use the --config flag
# If a run config is provided, we use the llama stack CLI
if [[ -n "$run_config" ]]; then
add_to_container << EOF
ENTRYPOINT ["python", "-m", "llama_stack.core.server.server", "$RUN_CONFIG_PATH"]
ENTRYPOINT ["llama", "stack", "run", "$RUN_CONFIG_PATH"]
EOF
elif [[ "$distro_or_config" != *.yaml ]]; then
add_to_container << EOF
ENTRYPOINT ["python", "-m", "llama_stack.core.server.server", "$distro_or_config"]
ENTRYPOINT ["llama", "stack", "run", "$distro_or_config"]
EOF
fi

View file

@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@ -0,0 +1,306 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import os
import secrets
import time
from typing import Any
from openai import NOT_GIVEN
from pydantic import BaseModel, TypeAdapter
from llama_stack.apis.conversations.conversations import (
Conversation,
ConversationDeletedResource,
ConversationItem,
ConversationItemDeletedResource,
ConversationItemList,
Conversations,
Metadata,
)
from llama_stack.core.datatypes import AccessRule
from llama_stack.core.utils.config_dirs import DISTRIBS_BASE_DIR
from llama_stack.log import get_logger
from llama_stack.providers.utils.sqlstore.api import ColumnDefinition, ColumnType
from llama_stack.providers.utils.sqlstore.authorized_sqlstore import AuthorizedSqlStore
from llama_stack.providers.utils.sqlstore.sqlstore import (
SqliteSqlStoreConfig,
SqlStoreConfig,
sqlstore_impl,
)
logger = get_logger(name=__name__, category="openai::conversations")
class ConversationServiceConfig(BaseModel):
"""Configuration for the built-in conversation service.
:param conversations_store: SQL store configuration for conversations (defaults to SQLite)
:param policy: Access control rules
"""
conversations_store: SqlStoreConfig = SqliteSqlStoreConfig(
db_path=(DISTRIBS_BASE_DIR / "conversations.db").as_posix()
)
policy: list[AccessRule] = []
async def get_provider_impl(config: ConversationServiceConfig, deps: dict[Any, Any]):
"""Get the conversation service implementation."""
impl = ConversationServiceImpl(config, deps)
await impl.initialize()
return impl
class ConversationServiceImpl(Conversations):
"""Built-in conversation service implementation using AuthorizedSqlStore."""
def __init__(self, config: ConversationServiceConfig, deps: dict[Any, Any]):
self.config = config
self.deps = deps
self.policy = config.policy
base_sql_store = sqlstore_impl(config.conversations_store)
self.sql_store = AuthorizedSqlStore(base_sql_store, self.policy)
async def initialize(self) -> None:
"""Initialize the store and create tables."""
if isinstance(self.config.conversations_store, SqliteSqlStoreConfig):
os.makedirs(os.path.dirname(self.config.conversations_store.db_path), exist_ok=True)
await self.sql_store.create_table(
"openai_conversations",
{
"id": ColumnDefinition(type=ColumnType.STRING, primary_key=True),
"created_at": ColumnType.INTEGER,
"items": ColumnType.JSON,
"metadata": ColumnType.JSON,
},
)
await self.sql_store.create_table(
"conversation_items",
{
"id": ColumnDefinition(type=ColumnType.STRING, primary_key=True),
"conversation_id": ColumnType.STRING,
"created_at": ColumnType.INTEGER,
"item_data": ColumnType.JSON,
},
)
async def create_conversation(
self, items: list[ConversationItem] | None = None, metadata: Metadata | None = None
) -> Conversation:
"""Create a conversation."""
random_bytes = secrets.token_bytes(24)
conversation_id = f"conv_{random_bytes.hex()}"
created_at = int(time.time())
record_data = {
"id": conversation_id,
"created_at": created_at,
"items": [],
"metadata": metadata,
}
await self.sql_store.insert(
table="openai_conversations",
data=record_data,
)
if items:
item_records = []
for item in items:
item_dict = item.model_dump()
item_id = self._get_or_generate_item_id(item, item_dict)
item_record = {
"id": item_id,
"conversation_id": conversation_id,
"created_at": created_at,
"item_data": item_dict,
}
item_records.append(item_record)
await self.sql_store.insert(table="conversation_items", data=item_records)
conversation = Conversation(
id=conversation_id,
created_at=created_at,
metadata=metadata,
object="conversation",
)
logger.info(f"Created conversation {conversation_id}")
return conversation
async def get_conversation(self, conversation_id: str) -> Conversation:
"""Get a conversation with the given ID."""
record = await self.sql_store.fetch_one(table="openai_conversations", where={"id": conversation_id})
if record is None:
raise ValueError(f"Conversation {conversation_id} not found")
return Conversation(
id=record["id"], created_at=record["created_at"], metadata=record.get("metadata"), object="conversation"
)
async def update_conversation(self, conversation_id: str, metadata: Metadata) -> Conversation:
"""Update a conversation's metadata with the given ID"""
await self.sql_store.update(
table="openai_conversations", data={"metadata": metadata}, where={"id": conversation_id}
)
return await self.get_conversation(conversation_id)
async def openai_delete_conversation(self, conversation_id: str) -> ConversationDeletedResource:
"""Delete a conversation with the given ID."""
await self.sql_store.delete(table="openai_conversations", where={"id": conversation_id})
logger.info(f"Deleted conversation {conversation_id}")
return ConversationDeletedResource(id=conversation_id)
def _validate_conversation_id(self, conversation_id: str) -> None:
"""Validate conversation ID format."""
if not conversation_id.startswith("conv_"):
raise ValueError(
f"Invalid 'conversation_id': '{conversation_id}'. Expected an ID that begins with 'conv_'."
)
def _get_or_generate_item_id(self, item: ConversationItem, item_dict: dict) -> str:
"""Get existing item ID or generate one if missing."""
if item.id is None:
random_bytes = secrets.token_bytes(24)
if item.type == "message":
item_id = f"msg_{random_bytes.hex()}"
else:
item_id = f"item_{random_bytes.hex()}"
item_dict["id"] = item_id
return item_id
return item.id
async def _get_validated_conversation(self, conversation_id: str) -> Conversation:
"""Validate conversation ID and return the conversation if it exists."""
self._validate_conversation_id(conversation_id)
return await self.get_conversation(conversation_id)
async def add_items(self, conversation_id: str, items: list[ConversationItem]) -> ConversationItemList:
"""Create (add) items to a conversation."""
await self._get_validated_conversation(conversation_id)
created_items = []
created_at = int(time.time())
for item in items:
item_dict = item.model_dump()
item_id = self._get_or_generate_item_id(item, item_dict)
item_record = {
"id": item_id,
"conversation_id": conversation_id,
"created_at": created_at,
"item_data": item_dict,
}
# TODO: Add support for upsert in sql_store, this will fail first if ID exists and then update
try:
await self.sql_store.insert(table="conversation_items", data=item_record)
except Exception:
# If insert fails due to ID conflict, update existing record
await self.sql_store.update(
table="conversation_items",
data={"created_at": created_at, "item_data": item_dict},
where={"id": item_id},
)
created_items.append(item_dict)
logger.info(f"Created {len(created_items)} items in conversation {conversation_id}")
# Convert created items (dicts) to proper ConversationItem types
adapter: TypeAdapter[ConversationItem] = TypeAdapter(ConversationItem)
response_items: list[ConversationItem] = [adapter.validate_python(item_dict) for item_dict in created_items]
return ConversationItemList(
data=response_items,
first_id=created_items[0]["id"] if created_items else None,
last_id=created_items[-1]["id"] if created_items else None,
has_more=False,
)
async def retrieve(self, conversation_id: str, item_id: str) -> ConversationItem:
"""Retrieve a conversation item."""
if not conversation_id:
raise ValueError(f"Expected a non-empty value for `conversation_id` but received {conversation_id!r}")
if not item_id:
raise ValueError(f"Expected a non-empty value for `item_id` but received {item_id!r}")
# Get item from conversation_items table
record = await self.sql_store.fetch_one(
table="conversation_items", where={"id": item_id, "conversation_id": conversation_id}
)
if record is None:
raise ValueError(f"Item {item_id} not found in conversation {conversation_id}")
adapter: TypeAdapter[ConversationItem] = TypeAdapter(ConversationItem)
return adapter.validate_python(record["item_data"])
async def list(self, conversation_id: str, after=NOT_GIVEN, include=NOT_GIVEN, limit=NOT_GIVEN, order=NOT_GIVEN):
"""List items in the conversation."""
result = await self.sql_store.fetch_all(table="conversation_items", where={"conversation_id": conversation_id})
records = result.data
if order != NOT_GIVEN and order == "asc":
records.sort(key=lambda x: x["created_at"])
else:
records.sort(key=lambda x: x["created_at"], reverse=True)
actual_limit = 20
if limit != NOT_GIVEN and isinstance(limit, int):
actual_limit = limit
records = records[:actual_limit]
items = [record["item_data"] for record in records]
adapter: TypeAdapter[ConversationItem] = TypeAdapter(ConversationItem)
response_items: list[ConversationItem] = [adapter.validate_python(item) for item in items]
first_id = response_items[0].id if response_items else None
last_id = response_items[-1].id if response_items else None
return ConversationItemList(
data=response_items,
first_id=first_id,
last_id=last_id,
has_more=False,
)
async def openai_delete_conversation_item(
self, conversation_id: str, item_id: str
) -> ConversationItemDeletedResource:
"""Delete a conversation item."""
if not conversation_id:
raise ValueError(f"Expected a non-empty value for `conversation_id` but received {conversation_id!r}")
if not item_id:
raise ValueError(f"Expected a non-empty value for `item_id` but received {item_id!r}")
_ = await self._get_validated_conversation(conversation_id)
record = await self.sql_store.fetch_one(
table="conversation_items", where={"id": item_id, "conversation_id": conversation_id}
)
if record is None:
raise ValueError(f"Item {item_id} not found in conversation {conversation_id}")
await self.sql_store.delete(
table="conversation_items", where={"id": item_id, "conversation_id": conversation_id}
)
logger.info(f"Deleted item {item_id} from conversation {conversation_id}")
return ConversationItemDeletedResource(id=item_id)

View file

@ -475,6 +475,13 @@ InferenceStoreConfig (with queue tuning parameters) or a SqlStoreConfig (depreca
If not specified, a default SQLite store will be used.""",
)
conversations_store: SqlStoreConfig | None = Field(
default=None,
description="""
Configuration for the persistence store used by the conversations API.
If not specified, a default SQLite store will be used.""",
)
# registry of "resources" in the distribution
models: list[ModelInput] = Field(default_factory=list)
shields: list[ShieldInput] = Field(default_factory=list)

View file

@ -25,7 +25,7 @@ from llama_stack.providers.datatypes import (
logger = get_logger(name=__name__, category="core")
INTERNAL_APIS = {Api.inspect, Api.providers, Api.prompts}
INTERNAL_APIS = {Api.inspect, Api.providers, Api.prompts, Api.conversations}
def stack_apis() -> list[Api]:
@ -243,6 +243,7 @@ def get_external_providers_from_module(
spec = module.get_provider_spec()
else:
# pass in a partially filled out provider spec to satisfy the registry -- knowing we will be overwriting it later upon build and run
# in the case we are building we CANNOT import this module of course because it has not been installed.
spec = ProviderSpec(
api=Api(provider_api),
provider_type=provider.provider_type,
@ -251,9 +252,20 @@ def get_external_providers_from_module(
config_class="",
)
provider_type = provider.provider_type
# in the case we are building we CANNOT import this module of course because it has not been installed.
# return a partially filled out spec that the build script will populate.
registry[Api(provider_api)][provider_type] = spec
if isinstance(spec, list):
# optionally allow people to pass inline and remote provider specs as a returned list.
# with the old method, users could pass in directories of specs using overlapping code
# we want to ensure we preserve that flexibility in this method.
logger.info(
f"Detected a list of external provider specs from {provider.module} adding all to the registry"
)
for provider_spec in spec:
if provider_spec.provider_type != provider.provider_type:
continue
logger.info(f"Adding {provider.provider_type} to registry")
registry[Api(provider_api)][provider.provider_type] = provider_spec
else:
registry[Api(provider_api)][provider_type] = spec
except ModuleNotFoundError as exc:
raise ValueError(
"get_provider_spec not found. If specifying an external provider via `module` in the Provider spec, the Provider must have the `provider.get_provider_spec` module available"

View file

@ -374,6 +374,10 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
body = options.params or {}
body |= options.json_data or {}
# Merge extra_json parameters (extra_body from SDK is converted to extra_json)
if hasattr(options, "extra_json") and options.extra_json:
body |= options.extra_json
matched_func, path_params, route_path, webmethod = find_matching_route(options.method, path, self.route_impls)
body |= path_params

View file

@ -10,6 +10,7 @@ from typing import Any
from llama_stack.apis.agents import Agents
from llama_stack.apis.batches import Batches
from llama_stack.apis.benchmarks import Benchmarks
from llama_stack.apis.conversations import Conversations
from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.datasets import Datasets
from llama_stack.apis.datatypes import ExternalApiSpec
@ -96,6 +97,7 @@ def api_protocol_map(external_apis: dict[Api, ExternalApiSpec] | None = None) ->
Api.tool_runtime: ToolRuntime,
Api.files: Files,
Api.prompts: Prompts,
Api.conversations: Conversations,
}
if external_apis:

View file

@ -19,7 +19,6 @@ from llama_stack.apis.inference import (
CompletionMessage,
Inference,
ListOpenAIChatCompletionResponse,
LogProbConfig,
Message,
OpenAIAssistantMessageParam,
OpenAIChatCompletion,
@ -34,12 +33,7 @@ from llama_stack.apis.inference import (
OpenAIMessageParam,
OpenAIResponseFormatParam,
Order,
ResponseFormat,
SamplingParams,
StopReason,
ToolChoice,
ToolConfig,
ToolDefinition,
ToolPromptFormat,
)
from llama_stack.apis.models import Model, ModelType
@ -193,102 +187,6 @@ class InferenceRouter(Inference):
raise ModelTypeError(model_id, model.model_type, expected_model_type)
return model
async def chat_completion(
self,
model_id: str,
messages: list[Message],
sampling_params: SamplingParams | None = None,
response_format: ResponseFormat | None = None,
tools: list[ToolDefinition] | None = None,
tool_choice: ToolChoice | None = None,
tool_prompt_format: ToolPromptFormat | None = None,
stream: bool | None = False,
logprobs: LogProbConfig | None = None,
tool_config: ToolConfig | None = None,
) -> ChatCompletionResponse | AsyncIterator[ChatCompletionResponseStreamChunk]:
logger.debug(
f"InferenceRouter.chat_completion: {model_id=}, {stream=}, {messages=}, {tools=}, {tool_config=}, {response_format=}",
)
if sampling_params is None:
sampling_params = SamplingParams()
model = await self._get_model(model_id, ModelType.llm)
if tool_config:
if tool_choice and tool_choice != tool_config.tool_choice:
raise ValueError("tool_choice and tool_config.tool_choice must match")
if (
tool_prompt_format
and tool_prompt_format != tool_config.tool_prompt_format
):
raise ValueError(
"tool_prompt_format and tool_config.tool_prompt_format must match"
)
else:
params = {}
if tool_choice:
params["tool_choice"] = tool_choice
if tool_prompt_format:
params["tool_prompt_format"] = tool_prompt_format
tool_config = ToolConfig(**params)
tools = tools or []
if tool_config.tool_choice == ToolChoice.none:
tools = []
elif tool_config.tool_choice == ToolChoice.auto:
pass
elif tool_config.tool_choice == ToolChoice.required:
pass
else:
# verify tool_choice is one of the tools
tool_names = [
t.tool_name if isinstance(t.tool_name, str) else t.tool_name.value
for t in tools
]
if tool_config.tool_choice not in tool_names:
raise ValueError(
f"Tool choice {tool_config.tool_choice} is not one of the tools: {tool_names}"
)
params = dict(
model_id=model_id,
messages=messages,
sampling_params=sampling_params,
tools=tools,
tool_choice=tool_choice,
tool_prompt_format=tool_prompt_format,
response_format=response_format,
stream=stream,
logprobs=logprobs,
tool_config=tool_config,
)
provider = await self.routing_table.get_provider_impl(model_id)
prompt_tokens = await self._count_tokens(
messages, tool_config.tool_prompt_format
)
if stream:
response_stream = await provider.chat_completion(**params)
return self.stream_tokens_and_compute_metrics(
response=response_stream,
prompt_tokens=prompt_tokens,
model=model,
tool_prompt_format=tool_config.tool_prompt_format,
)
response = await provider.chat_completion(**params)
metrics = await self.count_tokens_and_compute_metrics(
response=response,
prompt_tokens=prompt_tokens,
model=model,
tool_prompt_format=tool_config.tool_prompt_format,
)
# these metrics will show up in the client response.
response.metrics = (
metrics
if not hasattr(response, "metrics") or response.metrics is None
else response.metrics + metrics
)
return response
async def openai_completion(
self,
model: str,

View file

@ -4,7 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import argparse
import asyncio
import concurrent.futures
import functools
@ -12,7 +11,6 @@ import inspect
import json
import logging # allow-direct-logging
import os
import ssl
import sys
import traceback
import warnings
@ -35,7 +33,6 @@ from pydantic import BaseModel, ValidationError
from llama_stack.apis.common.errors import ConflictError, ResourceNotFoundError
from llama_stack.apis.common.responses import PaginatedResponse
from llama_stack.cli.utils import add_config_distro_args, get_config_from_args
from llama_stack.core.access_control.access_control import AccessDeniedError
from llama_stack.core.datatypes import (
AuthenticationRequiredError,
@ -55,7 +52,6 @@ from llama_stack.core.stack import (
Stack,
cast_image_name_to_string,
replace_env_vars,
validate_env_pair,
)
from llama_stack.core.utils.config import redact_sensitive_fields
from llama_stack.core.utils.config_resolution import Mode, resolve_config_or_distro
@ -333,23 +329,18 @@ class ClientVersionMiddleware:
return await self.app(scope, receive, send)
def create_app(
config_file: str | None = None,
env_vars: list[str] | None = None,
) -> StackApp:
def create_app() -> StackApp:
"""Create and configure the FastAPI application.
Args:
config_file: Path to config file. If None, uses LLAMA_STACK_CONFIG env var or default resolution.
env_vars: List of environment variables in KEY=value format.
disable_version_check: Whether to disable version checking. If None, uses LLAMA_STACK_DISABLE_VERSION_CHECK env var.
This factory function reads configuration from environment variables:
- LLAMA_STACK_CONFIG: Path to config file (required)
Returns:
Configured StackApp instance.
"""
config_file = config_file or os.getenv("LLAMA_STACK_CONFIG")
config_file = os.getenv("LLAMA_STACK_CONFIG")
if config_file is None:
raise ValueError("No config file provided and LLAMA_STACK_CONFIG env var is not set")
raise ValueError("LLAMA_STACK_CONFIG environment variable is required")
config_file = resolve_config_or_distro(config_file, Mode.RUN)
@ -361,16 +352,6 @@ def create_app(
logger_config = LoggingConfig(**cfg)
logger = get_logger(name=__name__, category="core::server", config=logger_config)
if env_vars:
for env_pair in env_vars:
try:
key, value = validate_env_pair(env_pair)
logger.info(f"Setting environment variable {key} => {value}")
os.environ[key] = value
except ValueError as e:
logger.error(f"Error: {str(e)}")
raise ValueError(f"Invalid environment variable format: {env_pair}") from e
config = replace_env_vars(config_contents)
config = StackRunConfig(**cast_image_name_to_string(config))
@ -451,6 +432,7 @@ def create_app(
apis_to_serve.add("inspect")
apis_to_serve.add("providers")
apis_to_serve.add("prompts")
apis_to_serve.add("conversations")
for api_str in apis_to_serve:
api = Api(api_str)
@ -493,101 +475,6 @@ def create_app(
return app
def main(args: argparse.Namespace | None = None):
"""Start the LlamaStack server."""
parser = argparse.ArgumentParser(description="Start the LlamaStack server.")
add_config_distro_args(parser)
parser.add_argument(
"--port",
type=int,
default=int(os.getenv("LLAMA_STACK_PORT", 8321)),
help="Port to listen on",
)
parser.add_argument(
"--env",
action="append",
help="Environment variables in KEY=value format. Can be specified multiple times.",
)
# Determine whether the server args are being passed by the "run" command, if this is the case
# the args will be passed as a Namespace object to the main function, otherwise they will be
# parsed from the command line
if args is None:
args = parser.parse_args()
config_or_distro = get_config_from_args(args)
try:
app = create_app(
config_file=config_or_distro,
env_vars=args.env,
)
except Exception as e:
logger.error(f"Error creating app: {str(e)}")
sys.exit(1)
config_file = resolve_config_or_distro(config_or_distro, Mode.RUN)
with open(config_file) as fp:
config_contents = yaml.safe_load(fp)
if isinstance(config_contents, dict) and (cfg := config_contents.get("logging_config")):
logger_config = LoggingConfig(**cfg)
else:
logger_config = None
config = StackRunConfig(**cast_image_name_to_string(replace_env_vars(config_contents)))
import uvicorn
# Configure SSL if certificates are provided
port = args.port or config.server.port
ssl_config = None
keyfile = config.server.tls_keyfile
certfile = config.server.tls_certfile
if keyfile and certfile:
ssl_config = {
"ssl_keyfile": keyfile,
"ssl_certfile": certfile,
}
if config.server.tls_cafile:
ssl_config["ssl_ca_certs"] = config.server.tls_cafile
ssl_config["ssl_cert_reqs"] = ssl.CERT_REQUIRED
logger.info(
f"HTTPS enabled with certificates:\n Key: {keyfile}\n Cert: {certfile}\n CA: {config.server.tls_cafile}"
)
else:
logger.info(f"HTTPS enabled with certificates:\n Key: {keyfile}\n Cert: {certfile}")
listen_host = config.server.host or ["::", "0.0.0.0"]
logger.info(f"Listening on {listen_host}:{port}")
uvicorn_config = {
"app": app,
"host": listen_host,
"port": port,
"lifespan": "on",
"log_level": logger.getEffectiveLevel(),
"log_config": logger_config,
}
if ssl_config:
uvicorn_config.update(ssl_config)
# We need to catch KeyboardInterrupt because uvicorn's signal handling
# re-raises SIGINT signals using signal.raise_signal(), which Python
# converts to KeyboardInterrupt. Without this catch, we'd get a confusing
# stack trace when using Ctrl+C or kill -2 (SIGINT).
# SIGTERM (kill -15) works fine without this because Python doesn't
# have a default handler for it.
#
# Another approach would be to ignore SIGINT entirely - let uvicorn handle it through its own
# signal handling but this is quite intrusive and not worth the effort.
try:
asyncio.run(uvicorn.Server(uvicorn.Config(**uvicorn_config)).serve())
except (KeyboardInterrupt, SystemExit):
logger.info("Received interrupt signal, shutting down gracefully...")
def _log_run_config(run_config: StackRunConfig):
"""Logs the run config with redacted fields and disabled providers removed."""
logger.info("Run configuration:")
@ -614,7 +501,3 @@ def remove_disabled_providers(obj):
return [item for item in (remove_disabled_providers(i) for i in obj) if item is not None]
else:
return obj
if __name__ == "__main__":
main()

View file

@ -15,6 +15,7 @@ import yaml
from llama_stack.apis.agents import Agents
from llama_stack.apis.benchmarks import Benchmarks
from llama_stack.apis.conversations import Conversations
from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.datasets import Datasets
from llama_stack.apis.eval import Eval
@ -34,6 +35,7 @@ from llama_stack.apis.telemetry import Telemetry
from llama_stack.apis.tools import RAGToolRuntime, ToolGroups, ToolRuntime
from llama_stack.apis.vector_dbs import VectorDBs
from llama_stack.apis.vector_io import VectorIO
from llama_stack.core.conversations.conversations import ConversationServiceConfig, ConversationServiceImpl
from llama_stack.core.datatypes import Provider, StackRunConfig
from llama_stack.core.distribution import get_provider_registry
from llama_stack.core.inspect import DistributionInspectConfig, DistributionInspectImpl
@ -73,6 +75,7 @@ class LlamaStack(
RAGToolRuntime,
Files,
Prompts,
Conversations,
):
pass
@ -312,6 +315,12 @@ def add_internal_implementations(impls: dict[Api, Any], run_config: StackRunConf
)
impls[Api.prompts] = prompts_impl
conversations_impl = ConversationServiceImpl(
ConversationServiceConfig(run_config=run_config),
deps=impls,
)
impls[Api.conversations] = conversations_impl
class Stack:
def __init__(self, run_config: StackRunConfig, provider_registry: ProviderRegistry | None = None):
@ -342,6 +351,8 @@ class Stack:
if Api.prompts in impls:
await impls[Api.prompts].initialize()
if Api.conversations in impls:
await impls[Api.conversations].initialize()
await register_resources(self.run_config, impls)

View file

@ -116,7 +116,7 @@ if [[ "$env_type" == "venv" ]]; then
yaml_config_arg=""
fi
$PYTHON_BINARY -m llama_stack.core.server.server \
llama stack run \
$yaml_config_arg \
--port "$port" \
$env_vars \

View file

@ -128,7 +128,7 @@ def strip_rich_markup(text):
class CustomRichHandler(RichHandler):
def __init__(self, *args, **kwargs):
kwargs["console"] = Console(width=150)
kwargs["console"] = Console()
super().__init__(*args, **kwargs)
def emit(self, record):

View file

@ -9,7 +9,7 @@ from pathlib import Path
from llama_stack.log import get_logger
logger = get_logger(__name__, "tokenizer_utils")
logger = get_logger(__name__, "models")
def load_bpe_file(model_path: Path) -> dict[bytes, int]:

View file

@ -329,6 +329,7 @@ class MetaReferenceAgentsImpl(Agents):
tools: list[OpenAIResponseInputTool] | None = None,
include: list[str] | None = None,
max_infer_iters: int | None = 10,
shields: list | None = None,
) -> OpenAIResponseObject:
return await self.openai_responses_impl.create_openai_response(
input,
@ -342,6 +343,7 @@ class MetaReferenceAgentsImpl(Agents):
tools,
include,
max_infer_iters,
shields,
)
async def list_openai_responses(

View file

@ -8,7 +8,7 @@ import time
import uuid
from collections.abc import AsyncIterator
from pydantic import BaseModel
from pydantic import BaseModel, TypeAdapter
from llama_stack.apis.agents import Order
from llama_stack.apis.agents.openai_responses import (
@ -26,12 +26,16 @@ from llama_stack.apis.agents.openai_responses import (
)
from llama_stack.apis.inference import (
Inference,
OpenAIMessageParam,
OpenAISystemMessageParam,
)
from llama_stack.apis.tools import ToolGroups, ToolRuntime
from llama_stack.apis.vector_io import VectorIO
from llama_stack.log import get_logger
from llama_stack.providers.utils.responses.responses_store import ResponsesStore
from llama_stack.providers.utils.responses.responses_store import (
ResponsesStore,
_OpenAIResponseObjectWithInputAndMessages,
)
from .streaming import StreamingResponseOrchestrator
from .tool_executor import ToolExecutor
@ -72,26 +76,48 @@ class OpenAIResponsesImpl:
async def _prepend_previous_response(
self,
input: str | list[OpenAIResponseInput],
previous_response_id: str | None = None,
previous_response: _OpenAIResponseObjectWithInputAndMessages,
):
new_input_items = previous_response.input.copy()
new_input_items.extend(previous_response.output)
if isinstance(input, str):
new_input_items.append(OpenAIResponseMessage(content=input, role="user"))
else:
new_input_items.extend(input)
return new_input_items
async def _process_input_with_previous_response(
self,
input: str | list[OpenAIResponseInput],
previous_response_id: str | None,
) -> tuple[str | list[OpenAIResponseInput], list[OpenAIMessageParam]]:
"""Process input with optional previous response context.
Returns:
tuple: (all_input for storage, messages for chat completion)
"""
if previous_response_id:
previous_response_with_input = await self.responses_store.get_response_object(previous_response_id)
previous_response: _OpenAIResponseObjectWithInputAndMessages = (
await self.responses_store.get_response_object(previous_response_id)
)
all_input = await self._prepend_previous_response(input, previous_response)
# previous response input items
new_input_items = previous_response_with_input.input
# previous response output items
new_input_items.extend(previous_response_with_input.output)
# new input items from the current request
if isinstance(input, str):
new_input_items.append(OpenAIResponseMessage(content=input, role="user"))
if previous_response.messages:
# Use stored messages directly and convert only new input
message_adapter = TypeAdapter(list[OpenAIMessageParam])
messages = message_adapter.validate_python(previous_response.messages)
new_messages = await convert_response_input_to_chat_messages(input)
messages.extend(new_messages)
else:
new_input_items.extend(input)
# Backward compatibility: reconstruct from inputs
messages = await convert_response_input_to_chat_messages(all_input)
else:
all_input = input
messages = await convert_response_input_to_chat_messages(input)
input = new_input_items
return input
return all_input, messages
async def _prepend_instructions(self, messages, instructions):
if instructions:
@ -102,7 +128,7 @@ class OpenAIResponsesImpl:
response_id: str,
) -> OpenAIResponseObject:
response_with_input = await self.responses_store.get_response_object(response_id)
return OpenAIResponseObject(**{k: v for k, v in response_with_input.model_dump().items() if k != "input"})
return response_with_input.to_response_object()
async def list_openai_responses(
self,
@ -138,6 +164,7 @@ class OpenAIResponsesImpl:
self,
response: OpenAIResponseObject,
input: str | list[OpenAIResponseInput],
messages: list[OpenAIMessageParam],
) -> None:
new_input_id = f"msg_{uuid.uuid4()}"
if isinstance(input, str):
@ -165,6 +192,7 @@ class OpenAIResponsesImpl:
await self.responses_store.store_response_object(
response_object=response,
input=input_items_data,
messages=messages,
)
async def create_openai_response(
@ -180,10 +208,15 @@ class OpenAIResponsesImpl:
tools: list[OpenAIResponseInputTool] | None = None,
include: list[str] | None = None,
max_infer_iters: int | None = 10,
shields: list | None = None,
):
stream = bool(stream)
text = OpenAIResponseText(format=OpenAIResponseTextFormat(type="text")) if text is None else text
# Shields parameter received via extra_body - not yet implemented
if shields is not None:
raise NotImplementedError("Shields parameter is not yet implemented in the meta-reference provider")
stream_gen = self._create_streaming_response(
input=input,
model=model,
@ -224,8 +257,7 @@ class OpenAIResponsesImpl:
max_infer_iters: int | None = 10,
) -> AsyncIterator[OpenAIResponseObjectStream]:
# Input preprocessing
input = await self._prepend_previous_response(input, previous_response_id)
messages = await convert_response_input_to_chat_messages(input)
all_input, messages = await self._process_input_with_previous_response(input, previous_response_id)
await self._prepend_instructions(messages, instructions)
# Structured outputs
@ -265,7 +297,8 @@ class OpenAIResponsesImpl:
if store and final_response:
await self._store_response(
response=final_response,
input=input,
input=all_input,
messages=orchestrator.final_messages,
)
async def delete_openai_response(self, response_id: str) -> OpenAIDeleteResponseObject:

View file

@ -43,6 +43,7 @@ from llama_stack.apis.inference import (
OpenAIChatCompletion,
OpenAIChatCompletionToolCall,
OpenAIChoice,
OpenAIMessageParam,
)
from llama_stack.log import get_logger
@ -94,6 +95,8 @@ class StreamingResponseOrchestrator:
self.sequence_number = 0
# Store MCP tool mapping that gets built during tool processing
self.mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP] = {}
# Track final messages after all tool executions
self.final_messages: list[OpenAIMessageParam] = []
async def create_response(self) -> AsyncIterator[OpenAIResponseObjectStream]:
# Initialize output messages
@ -183,6 +186,8 @@ class StreamingResponseOrchestrator:
messages = next_turn_messages
self.final_messages = messages.copy() + [current_response.choices[0].message]
# Create final response
final_response = OpenAIResponseObject(
created_at=self.created_at,

View file

@ -5,37 +5,17 @@
# the root directory of this source tree.
import asyncio
import os
import sys
from collections.abc import AsyncGenerator
from collections.abc import AsyncIterator
from typing import Any
from pydantic import BaseModel
from termcolor import cprint
from llama_stack.apis.common.content_types import (
TextDelta,
ToolCallDelta,
ToolCallParseStatus,
)
from llama_stack.apis.inference import (
ChatCompletionRequest,
ChatCompletionResponse,
ChatCompletionResponseEvent,
ChatCompletionResponseEventType,
ChatCompletionResponseStreamChunk,
CompletionMessage,
InferenceProvider,
LogProbConfig,
Message,
ResponseFormat,
SamplingParams,
StopReason,
TokenLogProbs,
ToolChoice,
ToolConfig,
ToolDefinition,
ToolPromptFormat,
UserMessage,
)
from llama_stack.apis.inference.inference import (
OpenAIChatCompletion,
OpenAIChatCompletionChunk,
OpenAIMessageParam,
OpenAIResponseFormatParam,
)
from llama_stack.apis.models import Model, ModelType
from llama_stack.log import get_logger
@ -53,13 +33,6 @@ from llama_stack.providers.utils.inference.model_registry import (
ModelRegistryHelper,
build_hf_repo_model_entry,
)
from llama_stack.providers.utils.inference.openai_compat import (
OpenAIChatCompletionToLlamaStackMixin,
)
from llama_stack.providers.utils.inference.prompt_adapter import (
chat_completion_request_to_messages,
convert_request_to_raw,
)
from .config import MetaReferenceInferenceConfig
from .generators import LlamaGenerator
@ -76,7 +49,6 @@ def llama_builder_fn(config: MetaReferenceInferenceConfig, model_id: str, llama_
class MetaReferenceInferenceImpl(
OpenAIChatCompletionToLlamaStackMixin,
SentenceTransformerEmbeddingMixin,
InferenceProvider,
ModelsProtocolPrivate,
@ -161,10 +133,10 @@ class MetaReferenceInferenceImpl(
self.llama_model = llama_model
log.info("Warming up...")
await self.chat_completion(
model_id=model_id,
messages=[UserMessage(content="Hi how are you?")],
sampling_params=SamplingParams(max_tokens=20),
await self.openai_chat_completion(
model=model_id,
messages=[{"role": "user", "content": "Hi how are you?"}],
max_tokens=20,
)
log.info("Warmed up!")
@ -176,242 +148,30 @@ class MetaReferenceInferenceImpl(
elif request.model != self.model_id:
raise RuntimeError(f"Model mismatch: request model: {request.model} != loaded model: {self.model_id}")
async def chat_completion(
async def openai_chat_completion(
self,
model_id: str,
messages: list[Message],
sampling_params: SamplingParams | None = None,
response_format: ResponseFormat | None = None,
tools: list[ToolDefinition] | None = None,
tool_choice: ToolChoice | None = ToolChoice.auto,
tool_prompt_format: ToolPromptFormat | None = None,
stream: bool | None = False,
logprobs: LogProbConfig | None = None,
tool_config: ToolConfig | None = None,
) -> AsyncGenerator:
if sampling_params is None:
sampling_params = SamplingParams()
if logprobs:
assert logprobs.top_k == 1, f"Unexpected top_k={logprobs.top_k}"
# wrapper request to make it easier to pass around (internal only, not exposed to API)
request = ChatCompletionRequest(
model=model_id,
messages=messages,
sampling_params=sampling_params,
tools=tools or [],
response_format=response_format,
stream=stream,
logprobs=logprobs,
tool_config=tool_config or ToolConfig(),
)
self.check_model(request)
# augment and rewrite messages depending on the model
request.messages = chat_completion_request_to_messages(request, self.llama_model.core_model_id.value)
# download media and convert to raw content so we can send it to the model
request = await convert_request_to_raw(request)
if self.config.create_distributed_process_group:
if SEMAPHORE.locked():
raise RuntimeError("Only one concurrent request is supported")
if request.stream:
return self._stream_chat_completion(request)
else:
results = await self._nonstream_chat_completion([request])
return results[0]
async def _nonstream_chat_completion(
self, request_batch: list[ChatCompletionRequest]
) -> list[ChatCompletionResponse]:
tokenizer = self.generator.formatter.tokenizer
first_request = request_batch[0]
class ItemState(BaseModel):
tokens: list[int] = []
logprobs: list[TokenLogProbs] = []
stop_reason: StopReason | None = None
finished: bool = False
def impl():
states = [ItemState() for _ in request_batch]
for token_results in self.generator.chat_completion(request_batch):
first = token_results[0]
if not first.finished and not first.ignore_token:
if os.environ.get("LLAMA_MODELS_DEBUG", "0") in ("1", "2"):
cprint(first.text, color="cyan", end="", file=sys.stderr)
if os.environ.get("LLAMA_MODELS_DEBUG", "0") == "2":
cprint(f"<{first.token}>", color="magenta", end="", file=sys.stderr)
for result in token_results:
idx = result.batch_idx
state = states[idx]
if state.finished or result.ignore_token:
continue
state.finished = result.finished
if first_request.logprobs:
state.logprobs.append(TokenLogProbs(logprobs_by_token={result.text: result.logprobs[0]}))
state.tokens.append(result.token)
if result.token == tokenizer.eot_id:
state.stop_reason = StopReason.end_of_turn
elif result.token == tokenizer.eom_id:
state.stop_reason = StopReason.end_of_message
results = []
for state in states:
if state.stop_reason is None:
state.stop_reason = StopReason.out_of_tokens
raw_message = self.generator.formatter.decode_assistant_message(state.tokens, state.stop_reason)
results.append(
ChatCompletionResponse(
completion_message=CompletionMessage(
content=raw_message.content,
stop_reason=raw_message.stop_reason,
tool_calls=raw_message.tool_calls,
),
logprobs=state.logprobs if first_request.logprobs else None,
)
)
return results
if self.config.create_distributed_process_group:
async with SEMAPHORE:
return impl()
else:
return impl()
async def _stream_chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
tokenizer = self.generator.formatter.tokenizer
def impl():
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.start,
delta=TextDelta(text=""),
)
)
tokens = []
logprobs = []
stop_reason = None
ipython = False
for token_results in self.generator.chat_completion([request]):
token_result = token_results[0]
if os.environ.get("LLAMA_MODELS_DEBUG", "0") == "1":
cprint(token_result.text, color="cyan", end="", file=sys.stderr)
if os.environ.get("LLAMA_MODELS_DEBUG", "0") == "2":
cprint(f"<{token_result.token}>", color="magenta", end="", file=sys.stderr)
if token_result.token == tokenizer.eot_id:
stop_reason = StopReason.end_of_turn
text = ""
elif token_result.token == tokenizer.eom_id:
stop_reason = StopReason.end_of_message
text = ""
else:
text = token_result.text
if request.logprobs:
assert len(token_result.logprobs) == 1
logprobs.append(TokenLogProbs(logprobs_by_token={token_result.text: token_result.logprobs[0]}))
tokens.append(token_result.token)
if not ipython and token_result.text.startswith("<|python_tag|>"):
ipython = True
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
delta=ToolCallDelta(
tool_call="",
parse_status=ToolCallParseStatus.started,
),
)
)
continue
if token_result.token == tokenizer.eot_id:
stop_reason = StopReason.end_of_turn
text = ""
elif token_result.token == tokenizer.eom_id:
stop_reason = StopReason.end_of_message
text = ""
else:
text = token_result.text
if ipython:
delta = ToolCallDelta(
tool_call=text,
parse_status=ToolCallParseStatus.in_progress,
)
else:
delta = TextDelta(text=text)
if stop_reason is None:
if request.logprobs:
assert len(token_result.logprobs) == 1
logprobs.append(TokenLogProbs(logprobs_by_token={token_result.text: token_result.logprobs[0]}))
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
delta=delta,
stop_reason=stop_reason,
logprobs=logprobs if request.logprobs else None,
)
)
if stop_reason is None:
stop_reason = StopReason.out_of_tokens
message = self.generator.formatter.decode_assistant_message(tokens, stop_reason)
parsed_tool_calls = len(message.tool_calls) > 0
if ipython and not parsed_tool_calls:
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
delta=ToolCallDelta(
tool_call="",
parse_status=ToolCallParseStatus.failed,
),
stop_reason=stop_reason,
)
)
for tool_call in message.tool_calls:
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
delta=ToolCallDelta(
tool_call=tool_call,
parse_status=ToolCallParseStatus.succeeded,
),
stop_reason=stop_reason,
)
)
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.complete,
delta=TextDelta(text=""),
stop_reason=stop_reason,
)
)
if self.config.create_distributed_process_group:
async with SEMAPHORE:
for x in impl():
yield x
else:
for x in impl():
yield x
model: str,
messages: list[OpenAIMessageParam],
frequency_penalty: float | None = None,
function_call: str | dict[str, Any] | None = None,
functions: list[dict[str, Any]] | None = None,
logit_bias: dict[str, float] | None = None,
logprobs: bool | None = None,
max_completion_tokens: int | None = None,
max_tokens: int | None = None,
n: int | None = None,
parallel_tool_calls: bool | None = None,
presence_penalty: float | None = None,
response_format: OpenAIResponseFormatParam | None = None,
seed: int | None = None,
stop: str | list[str] | None = None,
stream: bool | None = None,
stream_options: dict[str, Any] | None = None,
temperature: float | None = None,
tool_choice: str | dict[str, Any] | None = None,
tools: list[dict[str, Any]] | None = None,
top_logprobs: int | None = None,
top_p: float | None = None,
user: str | None = None,
) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
raise NotImplementedError("OpenAI chat completion not supported by meta-reference inference provider")

View file

@ -4,21 +4,19 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from collections.abc import AsyncGenerator
from collections.abc import AsyncIterator
from typing import Any
from llama_stack.apis.inference import (
InferenceProvider,
LogProbConfig,
Message,
ResponseFormat,
SamplingParams,
ToolChoice,
ToolConfig,
ToolDefinition,
ToolPromptFormat,
)
from llama_stack.apis.inference.inference import OpenAICompletion
from llama_stack.apis.inference.inference import (
OpenAIChatCompletion,
OpenAIChatCompletionChunk,
OpenAICompletion,
OpenAIMessageParam,
OpenAIResponseFormatParam,
)
from llama_stack.apis.models import ModelType
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import Model, ModelsProtocolPrivate
@ -69,21 +67,6 @@ class SentenceTransformersInferenceImpl(
async def unregister_model(self, model_id: str) -> None:
pass
async def chat_completion(
self,
model_id: str,
messages: list[Message],
sampling_params: SamplingParams | None = None,
response_format: ResponseFormat | None = None,
tools: list[ToolDefinition] | None = None,
tool_choice: ToolChoice | None = ToolChoice.auto,
tool_prompt_format: ToolPromptFormat | None = None,
stream: bool | None = False,
logprobs: LogProbConfig | None = None,
tool_config: ToolConfig | None = None,
) -> AsyncGenerator:
raise ValueError("Sentence transformers don't support chat completion")
async def openai_completion(
self,
# Standard OpenAI completion parameters
@ -110,6 +93,32 @@ class SentenceTransformersInferenceImpl(
# for fill-in-the-middle type completion
suffix: str | None = None,
) -> OpenAICompletion:
raise NotImplementedError(
"OpenAI completion not supported by sentence transformers provider"
)
raise NotImplementedError("OpenAI completion not supported by sentence transformers provider")
async def openai_chat_completion(
self,
model: str,
messages: list[OpenAIMessageParam],
frequency_penalty: float | None = None,
function_call: str | dict[str, Any] | None = None,
functions: list[dict[str, Any]] | None = None,
logit_bias: dict[str, float] | None = None,
logprobs: bool | None = None,
max_completion_tokens: int | None = None,
max_tokens: int | None = None,
n: int | None = None,
parallel_tool_calls: bool | None = None,
presence_penalty: float | None = None,
response_format: OpenAIResponseFormatParam | None = None,
seed: int | None = None,
stop: str | list[str] | None = None,
stream: bool | None = None,
stream_options: dict[str, Any] | None = None,
temperature: float | None = None,
tool_choice: str | dict[str, Any] | None = None,
tools: list[dict[str, Any]] | None = None,
top_logprobs: int | None = None,
top_p: float | None = None,
user: str | None = None,
) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
raise NotImplementedError("OpenAI chat completion not supported by sentence transformers provider")

View file

@ -52,9 +52,7 @@ def available_providers() -> list[ProviderSpec]:
api=Api.inference,
adapter_type="cerebras",
provider_type="remote::cerebras",
pip_packages=[
"cerebras_cloud_sdk",
],
pip_packages=[],
module="llama_stack.providers.remote.inference.cerebras",
config_class="llama_stack.providers.remote.inference.cerebras.CerebrasImplConfig",
description="Cerebras inference provider for running models on Cerebras Cloud platform.",
@ -169,7 +167,7 @@ def available_providers() -> list[ProviderSpec]:
api=Api.inference,
adapter_type="openai",
provider_type="remote::openai",
pip_packages=["litellm"],
pip_packages=[],
module="llama_stack.providers.remote.inference.openai",
config_class="llama_stack.providers.remote.inference.openai.OpenAIConfig",
provider_data_validator="llama_stack.providers.remote.inference.openai.config.OpenAIProviderDataValidator",
@ -179,7 +177,7 @@ def available_providers() -> list[ProviderSpec]:
api=Api.inference,
adapter_type="anthropic",
provider_type="remote::anthropic",
pip_packages=["litellm"],
pip_packages=["anthropic"],
module="llama_stack.providers.remote.inference.anthropic",
config_class="llama_stack.providers.remote.inference.anthropic.AnthropicConfig",
provider_data_validator="llama_stack.providers.remote.inference.anthropic.config.AnthropicProviderDataValidator",
@ -189,9 +187,7 @@ def available_providers() -> list[ProviderSpec]:
api=Api.inference,
adapter_type="gemini",
provider_type="remote::gemini",
pip_packages=[
"litellm",
],
pip_packages=[],
module="llama_stack.providers.remote.inference.gemini",
config_class="llama_stack.providers.remote.inference.gemini.GeminiConfig",
provider_data_validator="llama_stack.providers.remote.inference.gemini.config.GeminiProviderDataValidator",
@ -202,7 +198,6 @@ def available_providers() -> list[ProviderSpec]:
adapter_type="vertexai",
provider_type="remote::vertexai",
pip_packages=[
"litellm",
"google-cloud-aiplatform",
],
module="llama_stack.providers.remote.inference.vertexai",
@ -233,9 +228,7 @@ Available Models:
api=Api.inference,
adapter_type="groq",
provider_type="remote::groq",
pip_packages=[
"litellm",
],
pip_packages=[],
module="llama_stack.providers.remote.inference.groq",
config_class="llama_stack.providers.remote.inference.groq.GroqConfig",
provider_data_validator="llama_stack.providers.remote.inference.groq.config.GroqProviderDataValidator",
@ -245,7 +238,7 @@ Available Models:
api=Api.inference,
adapter_type="llama-openai-compat",
provider_type="remote::llama-openai-compat",
pip_packages=["litellm"],
pip_packages=[],
module="llama_stack.providers.remote.inference.llama_openai_compat",
config_class="llama_stack.providers.remote.inference.llama_openai_compat.config.LlamaCompatConfig",
provider_data_validator="llama_stack.providers.remote.inference.llama_openai_compat.config.LlamaProviderDataValidator",
@ -255,9 +248,7 @@ Available Models:
api=Api.inference,
adapter_type="sambanova",
provider_type="remote::sambanova",
pip_packages=[
"litellm",
],
pip_packages=[],
module="llama_stack.providers.remote.inference.sambanova",
config_class="llama_stack.providers.remote.inference.sambanova.SambaNovaImplConfig",
provider_data_validator="llama_stack.providers.remote.inference.sambanova.config.SambaNovaProviderDataValidator",
@ -287,7 +278,7 @@ Available Models:
api=Api.inference,
provider_type="remote::azure",
adapter_type="azure",
pip_packages=["litellm"],
pip_packages=[],
module="llama_stack.providers.remote.inference.azure",
config_class="llama_stack.providers.remote.inference.azure.AzureConfig",
provider_data_validator="llama_stack.providers.remote.inference.azure.config.AzureProviderDataValidator",

View file

@ -500,7 +500,7 @@ See [PGVector's documentation](https://github.com/pgvector/pgvector) for more de
api=Api.vector_io,
adapter_type="weaviate",
provider_type="remote::weaviate",
pip_packages=["weaviate-client"],
pip_packages=["weaviate-client>=4.16.5"],
module="llama_stack.providers.remote.vector_io.weaviate",
config_class="llama_stack.providers.remote.vector_io.weaviate.WeaviateVectorIOConfig",
provider_data_validator="llama_stack.providers.remote.vector_io.weaviate.WeaviateRequestProviderData",

View file

@ -10,6 +10,6 @@ from .config import AnthropicConfig
async def get_adapter_impl(config: AnthropicConfig, _deps):
from .anthropic import AnthropicInferenceAdapter
impl = AnthropicInferenceAdapter(config)
impl = AnthropicInferenceAdapter(config=config)
await impl.initialize()
return impl

View file

@ -4,13 +4,19 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
from collections.abc import Iterable
from anthropic import AsyncAnthropic
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
from .config import AnthropicConfig
class AnthropicInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin):
class AnthropicInferenceAdapter(OpenAIMixin):
config: AnthropicConfig
provider_data_api_key_field: str = "anthropic_api_key"
# source: https://docs.claude.com/en/docs/build-with-claude/embeddings
# TODO: add support for voyageai, which is where these models are hosted
# embedding_model_metadata = {
@ -23,22 +29,11 @@ class AnthropicInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin):
# "voyage-multimodal-3": {"embedding_dimension": 1024, "context_length": 32000},
# }
def __init__(self, config: AnthropicConfig) -> None:
LiteLLMOpenAIMixin.__init__(
self,
litellm_provider_name="anthropic",
api_key_from_config=config.api_key,
provider_data_api_key_field="anthropic_api_key",
)
self.config = config
async def initialize(self) -> None:
await super().initialize()
async def shutdown(self) -> None:
await super().shutdown()
get_api_key = LiteLLMOpenAIMixin.get_api_key
def get_api_key(self) -> str:
return self.config.api_key or ""
def get_base_url(self):
return "https://api.anthropic.com/v1"
async def list_provider_model_ids(self) -> Iterable[str]:
return [m.id async for m in AsyncAnthropic(api_key=self.get_api_key()).models.list()]

View file

@ -8,6 +8,7 @@ from typing import Any
from pydantic import BaseModel, Field
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
from llama_stack.schema_utils import json_schema_type
@ -19,7 +20,7 @@ class AnthropicProviderDataValidator(BaseModel):
@json_schema_type
class AnthropicConfig(BaseModel):
class AnthropicConfig(RemoteInferenceProviderConfig):
api_key: str | None = Field(
default=None,
description="API key for Anthropic models",

View file

@ -10,6 +10,6 @@ from .config import AzureConfig
async def get_adapter_impl(config: AzureConfig, _deps):
from .azure import AzureInferenceAdapter
impl = AzureInferenceAdapter(config)
impl = AzureInferenceAdapter(config=config)
await impl.initialize()
return impl

View file

@ -4,31 +4,20 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any
from urllib.parse import urljoin
from llama_stack.apis.inference import ChatCompletionRequest
from llama_stack.providers.utils.inference.litellm_openai_mixin import (
LiteLLMOpenAIMixin,
)
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
from .config import AzureConfig
class AzureInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin):
def __init__(self, config: AzureConfig) -> None:
LiteLLMOpenAIMixin.__init__(
self,
litellm_provider_name="azure",
api_key_from_config=config.api_key.get_secret_value(),
provider_data_api_key_field="azure_api_key",
openai_compat_api_base=str(config.api_base),
)
self.config = config
class AzureInferenceAdapter(OpenAIMixin):
config: AzureConfig
# Delegate the client data handling get_api_key method to LiteLLMOpenAIMixin
get_api_key = LiteLLMOpenAIMixin.get_api_key
provider_data_api_key_field: str = "azure_api_key"
def get_api_key(self) -> str:
return self.config.api_key.get_secret_value()
def get_base_url(self) -> str:
"""
@ -37,26 +26,3 @@ class AzureInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin):
Returns the Azure API base URL from the configuration.
"""
return urljoin(str(self.config.api_base), "/openai/v1")
async def _get_params(self, request: ChatCompletionRequest) -> dict[str, Any]:
# Get base parameters from parent
params = await super()._get_params(request)
# Add Azure specific parameters
provider_data = self.get_request_provider_data()
if provider_data:
if getattr(provider_data, "azure_api_key", None):
params["api_key"] = provider_data.azure_api_key
if getattr(provider_data, "azure_api_base", None):
params["api_base"] = provider_data.azure_api_base
if getattr(provider_data, "azure_api_version", None):
params["api_version"] = provider_data.azure_api_version
if getattr(provider_data, "azure_api_type", None):
params["api_type"] = provider_data.azure_api_type
else:
params["api_key"] = self.config.api_key.get_secret_value()
params["api_base"] = str(self.config.api_base)
params["api_version"] = self.config.api_version
params["api_type"] = self.config.api_type
return params

View file

@ -9,6 +9,7 @@ from typing import Any
from pydantic import BaseModel, Field, HttpUrl, SecretStr
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
from llama_stack.schema_utils import json_schema_type
@ -30,7 +31,7 @@ class AzureProviderDataValidator(BaseModel):
@json_schema_type
class AzureConfig(BaseModel):
class AzureConfig(RemoteInferenceProviderConfig):
api_key: SecretStr = Field(
description="Azure API key for Azure",
)

View file

@ -5,39 +5,30 @@
# the root directory of this source tree.
import json
from collections.abc import AsyncGenerator, AsyncIterator
from collections.abc import AsyncIterator
from typing import Any
from botocore.client import BaseClient
from llama_stack.apis.inference import (
ChatCompletionRequest,
ChatCompletionResponse,
ChatCompletionResponseStreamChunk,
Inference,
LogProbConfig,
Message,
OpenAIEmbeddingsResponse,
ResponseFormat,
SamplingParams,
ToolChoice,
ToolConfig,
ToolDefinition,
ToolPromptFormat,
)
from llama_stack.apis.inference.inference import OpenAICompletion
from llama_stack.apis.inference.inference import (
OpenAIChatCompletion,
OpenAIChatCompletionChunk,
OpenAICompletion,
OpenAIMessageParam,
OpenAIResponseFormatParam,
)
from llama_stack.providers.remote.inference.bedrock.config import BedrockConfig
from llama_stack.providers.utils.bedrock.client import create_bedrock_client
from llama_stack.providers.utils.inference.model_registry import (
ModelRegistryHelper,
)
from llama_stack.providers.utils.inference.openai_compat import (
OpenAIChatCompletionToLlamaStackMixin,
OpenAICompatCompletionChoice,
OpenAICompatCompletionResponse,
get_sampling_strategy_options,
process_chat_completion_response,
process_chat_completion_stream_response,
)
from llama_stack.providers.utils.inference.prompt_adapter import (
chat_completion_request_to_prompt,
@ -86,7 +77,6 @@ def _to_inference_profile_id(model_id: str, region: str = None) -> str:
class BedrockInferenceAdapter(
ModelRegistryHelper,
Inference,
OpenAIChatCompletionToLlamaStackMixin,
):
def __init__(self, config: BedrockConfig) -> None:
ModelRegistryHelper.__init__(self, model_entries=MODEL_ENTRIES)
@ -106,71 +96,6 @@ class BedrockInferenceAdapter(
if self._client is not None:
self._client.close()
async def chat_completion(
self,
model_id: str,
messages: list[Message],
sampling_params: SamplingParams | None = None,
response_format: ResponseFormat | None = None,
tools: list[ToolDefinition] | None = None,
tool_choice: ToolChoice | None = ToolChoice.auto,
tool_prompt_format: ToolPromptFormat | None = None,
stream: bool | None = False,
logprobs: LogProbConfig | None = None,
tool_config: ToolConfig | None = None,
) -> ChatCompletionResponse | AsyncIterator[ChatCompletionResponseStreamChunk]:
if sampling_params is None:
sampling_params = SamplingParams()
model = await self.model_store.get_model(model_id)
request = ChatCompletionRequest(
model=model.provider_resource_id,
messages=messages,
sampling_params=sampling_params,
tools=tools or [],
response_format=response_format,
stream=stream,
logprobs=logprobs,
tool_config=tool_config,
)
if stream:
return self._stream_chat_completion(request)
else:
return await self._nonstream_chat_completion(request)
async def _nonstream_chat_completion(self, request: ChatCompletionRequest) -> ChatCompletionResponse:
params = await self._get_params_for_chat_completion(request)
res = self.client.invoke_model(**params)
chunk = next(res["body"])
result = json.loads(chunk.decode("utf-8"))
choice = OpenAICompatCompletionChoice(
finish_reason=result["stop_reason"],
text=result["generation"],
)
response = OpenAICompatCompletionResponse(choices=[choice])
return process_chat_completion_response(response, request)
async def _stream_chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
params = await self._get_params_for_chat_completion(request)
res = self.client.invoke_model_with_response_stream(**params)
event_stream = res["body"]
async def _generate_and_convert_to_openai_compat():
for chunk in event_stream:
chunk = chunk["chunk"]["bytes"]
result = json.loads(chunk.decode("utf-8"))
choice = OpenAICompatCompletionChoice(
finish_reason=result["stop_reason"],
text=result["generation"],
)
yield OpenAICompatCompletionResponse(choices=[choice])
stream = _generate_and_convert_to_openai_compat()
async for chunk in process_chat_completion_stream_response(stream, request):
yield chunk
async def _get_params_for_chat_completion(self, request: ChatCompletionRequest) -> dict:
bedrock_model = request.model
@ -235,3 +160,31 @@ class BedrockInferenceAdapter(
suffix: str | None = None,
) -> OpenAICompletion:
raise NotImplementedError("OpenAI completion not supported by the Bedrock provider")
async def openai_chat_completion(
self,
model: str,
messages: list[OpenAIMessageParam],
frequency_penalty: float | None = None,
function_call: str | dict[str, Any] | None = None,
functions: list[dict[str, Any]] | None = None,
logit_bias: dict[str, float] | None = None,
logprobs: bool | None = None,
max_completion_tokens: int | None = None,
max_tokens: int | None = None,
n: int | None = None,
parallel_tool_calls: bool | None = None,
presence_penalty: float | None = None,
response_format: OpenAIResponseFormatParam | None = None,
seed: int | None = None,
stop: str | list[str] | None = None,
stream: bool | None = None,
stream_options: dict[str, Any] | None = None,
temperature: float | None = None,
tool_choice: str | dict[str, Any] | None = None,
tools: list[dict[str, Any]] | None = None,
top_logprobs: int | None = None,
top_p: float | None = None,
user: str | None = None,
) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
raise NotImplementedError("OpenAI chat completion not supported by the Bedrock provider")

View file

@ -12,7 +12,7 @@ async def get_adapter_impl(config: CerebrasImplConfig, _deps):
assert isinstance(config, CerebrasImplConfig), f"Unexpected config type: {type(config)}"
impl = CerebrasInferenceAdapter(config)
impl = CerebrasInferenceAdapter(config=config)
await impl.initialize()

View file

@ -4,53 +4,16 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from collections.abc import AsyncGenerator
from urllib.parse import urljoin
from cerebras.cloud.sdk import AsyncCerebras
from llama_stack.apis.inference import (
ChatCompletionRequest,
ChatCompletionResponse,
Inference,
LogProbConfig,
Message,
OpenAIEmbeddingsResponse,
ResponseFormat,
SamplingParams,
ToolChoice,
ToolConfig,
ToolDefinition,
ToolPromptFormat,
TopKSamplingStrategy,
)
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
from llama_stack.providers.utils.inference.openai_compat import (
get_sampling_options,
process_chat_completion_response,
process_chat_completion_stream_response,
)
from llama_stack.apis.inference import OpenAIEmbeddingsResponse
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
from llama_stack.providers.utils.inference.prompt_adapter import (
chat_completion_request_to_prompt,
)
from .config import CerebrasImplConfig
class CerebrasInferenceAdapter(
OpenAIMixin,
ModelRegistryHelper,
Inference,
):
def __init__(self, config: CerebrasImplConfig) -> None:
self.config = config
# TODO: make this use provider data, etc. like other providers
self._cerebras_client = AsyncCerebras(
base_url=self.config.base_url,
api_key=self.config.api_key.get_secret_value(),
)
class CerebrasInferenceAdapter(OpenAIMixin):
config: CerebrasImplConfig
def get_api_key(self) -> str:
return self.config.api_key.get_secret_value()
@ -58,86 +21,6 @@ class CerebrasInferenceAdapter(
def get_base_url(self) -> str:
return urljoin(self.config.base_url, "v1")
async def initialize(self) -> None:
return
async def shutdown(self) -> None:
pass
async def chat_completion(
self,
model_id: str,
messages: list[Message],
sampling_params: SamplingParams | None = None,
tools: list[ToolDefinition] | None = None,
tool_choice: ToolChoice | None = ToolChoice.auto,
tool_prompt_format: ToolPromptFormat | None = None,
response_format: ResponseFormat | None = None,
stream: bool | None = False,
logprobs: LogProbConfig | None = None,
tool_config: ToolConfig | None = None,
) -> AsyncGenerator:
if sampling_params is None:
sampling_params = SamplingParams()
model = await self.model_store.get_model(model_id)
request = ChatCompletionRequest(
model=model.provider_resource_id,
messages=messages,
sampling_params=sampling_params,
tools=tools or [],
tool_choice=tool_choice,
tool_prompt_format=tool_prompt_format,
response_format=response_format,
stream=stream,
logprobs=logprobs,
tool_config=tool_config,
)
if stream:
return self._stream_chat_completion(request)
else:
return await self._nonstream_chat_completion(request)
async def _nonstream_chat_completion(
self, request: ChatCompletionRequest
) -> ChatCompletionResponse:
params = await self._get_params(request)
r = await self._cerebras_client.completions.create(**params)
return process_chat_completion_response(r, request)
async def _stream_chat_completion(
self, request: ChatCompletionRequest
) -> AsyncGenerator:
params = await self._get_params(request)
stream = await self._cerebras_client.completions.create(**params)
async for chunk in process_chat_completion_stream_response(stream, request):
yield chunk
async def _get_params(self, request: ChatCompletionRequest) -> dict:
if request.sampling_params and isinstance(
request.sampling_params.strategy, TopKSamplingStrategy
):
raise ValueError("`top_k` not supported by Cerebras")
prompt = ""
if isinstance(request, ChatCompletionRequest):
prompt = await chat_completion_request_to_prompt(
request, self.get_llama_model(request.model)
)
else:
raise ValueError(f"Unknown request type {type(request)}")
return {
"model": request.model,
"prompt": prompt,
"stream": request.stream,
**get_sampling_options(request.sampling_params),
}
async def openai_embeddings(
self,
model: str,

View file

@ -7,21 +7,22 @@
import os
from typing import Any
from pydantic import BaseModel, Field, SecretStr
from pydantic import Field, SecretStr
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
from llama_stack.schema_utils import json_schema_type
DEFAULT_BASE_URL = "https://api.cerebras.ai"
@json_schema_type
class CerebrasImplConfig(BaseModel):
class CerebrasImplConfig(RemoteInferenceProviderConfig):
base_url: str = Field(
default=os.environ.get("CEREBRAS_BASE_URL", DEFAULT_BASE_URL),
description="Base URL for the Cerebras API",
)
api_key: SecretStr = Field(
default=SecretStr(os.environ.get("CEREBRAS_API_KEY")),
default=SecretStr(os.environ.get("CEREBRAS_API_KEY")), # type: ignore[arg-type]
description="Cerebras API Key",
)

View file

@ -11,6 +11,6 @@ async def get_adapter_impl(config: DatabricksImplConfig, _deps):
from .databricks import DatabricksInferenceAdapter
assert isinstance(config, DatabricksImplConfig), f"Unexpected config type: {type(config)}"
impl = DatabricksInferenceAdapter(config)
impl = DatabricksInferenceAdapter(config=config)
await impl.initialize()
return impl

View file

@ -6,19 +6,20 @@
from typing import Any
from pydantic import BaseModel, Field, SecretStr
from pydantic import Field, SecretStr
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
from llama_stack.schema_utils import json_schema_type
@json_schema_type
class DatabricksImplConfig(BaseModel):
url: str = Field(
class DatabricksImplConfig(RemoteInferenceProviderConfig):
url: str | None = Field(
default=None,
description="The URL for the Databricks model serving endpoint",
)
api_token: SecretStr = Field(
default=SecretStr(None),
default=SecretStr(None), # type: ignore[arg-type]
description="The Databricks API token",
)

View file

@ -4,27 +4,12 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from collections.abc import AsyncIterator
from collections.abc import Iterable
from typing import Any
from databricks.sdk import WorkspaceClient
from llama_stack.apis.inference import (
ChatCompletionResponse,
ChatCompletionResponseStreamChunk,
Inference,
LogProbConfig,
Message,
Model,
OpenAICompletion,
ResponseFormat,
SamplingParams,
ToolChoice,
ToolConfig,
ToolDefinition,
ToolPromptFormat,
)
from llama_stack.apis.models import ModelType
from llama_stack.apis.inference import OpenAICompletion
from llama_stack.log import get_logger
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
@ -33,30 +18,31 @@ from .config import DatabricksImplConfig
logger = get_logger(name=__name__, category="inference::databricks")
class DatabricksInferenceAdapter(
OpenAIMixin,
Inference,
):
class DatabricksInferenceAdapter(OpenAIMixin):
config: DatabricksImplConfig
# source: https://docs.databricks.com/aws/en/machine-learning/foundation-model-apis/supported-models
embedding_model_metadata = {
embedding_model_metadata: dict[str, dict[str, int]] = {
"databricks-gte-large-en": {"embedding_dimension": 1024, "context_length": 8192},
"databricks-bge-large-en": {"embedding_dimension": 1024, "context_length": 512},
}
def __init__(self, config: DatabricksImplConfig) -> None:
self.config = config
def get_api_key(self) -> str:
return self.config.api_token.get_secret_value()
def get_base_url(self) -> str:
return f"{self.config.url}/serving-endpoints"
async def initialize(self) -> None:
return
async def list_provider_model_ids(self) -> Iterable[str]:
return [
endpoint.name
for endpoint in WorkspaceClient(
host=self.config.url, token=self.get_api_key()
).serving_endpoints.list() # TODO: this is not async
]
async def shutdown(self) -> None:
pass
async def should_refresh_models(self) -> bool:
return False
async def openai_completion(
self,
@ -82,47 +68,3 @@ class DatabricksInferenceAdapter(
suffix: str | None = None,
) -> OpenAICompletion:
raise NotImplementedError()
async def chat_completion(
self,
model_id: str,
messages: list[Message],
sampling_params: SamplingParams | None = None,
tools: list[ToolDefinition] | None = None,
tool_choice: ToolChoice | None = ToolChoice.auto,
tool_prompt_format: ToolPromptFormat | None = None,
response_format: ResponseFormat | None = None,
stream: bool | None = False,
logprobs: LogProbConfig | None = None,
tool_config: ToolConfig | None = None,
) -> ChatCompletionResponse | AsyncIterator[ChatCompletionResponseStreamChunk]:
raise NotImplementedError()
async def list_models(self) -> list[Model] | None:
self._model_cache = {} # from OpenAIMixin
ws_client = WorkspaceClient(host=self.config.url, token=self.get_api_key()) # TODO: this is not async
endpoints = ws_client.serving_endpoints.list()
for endpoint in endpoints:
model = Model(
provider_id=self.__provider_id__,
provider_resource_id=endpoint.name,
identifier=endpoint.name,
)
if endpoint.task == "llm/v1/chat":
model.model_type = ModelType.llm # this is redundant, but informative
elif endpoint.task == "llm/v1/embeddings":
if endpoint.name not in self.embedding_model_metadata:
logger.warning(f"No metadata information available for embedding model {endpoint.name}, skipping.")
continue
model.model_type = ModelType.embedding
model.metadata = self.embedding_model_metadata[endpoint.name]
else:
logger.warning(f"Unknown model type, skipping: {endpoint}")
continue
self._model_cache[endpoint.name] = model
return list(self._model_cache.values())
async def should_refresh_models(self) -> bool:
return False

View file

@ -17,6 +17,6 @@ async def get_adapter_impl(config: FireworksImplConfig, _deps):
from .fireworks import FireworksInferenceAdapter
assert isinstance(config, FireworksImplConfig), f"Unexpected config type: {type(config)}"
impl = FireworksInferenceAdapter(config)
impl = FireworksInferenceAdapter(config=config)
await impl.initialize()
return impl

View file

@ -4,195 +4,27 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from collections.abc import AsyncGenerator
from fireworks.client import Fireworks
from llama_stack.apis.inference import (
ChatCompletionRequest,
ChatCompletionResponse,
Inference,
LogProbConfig,
Message,
ResponseFormat,
ResponseFormatType,
SamplingParams,
ToolChoice,
ToolConfig,
ToolDefinition,
ToolPromptFormat,
)
from llama_stack.core.request_headers import NeedsRequestProviderData
from llama_stack.log import get_logger
from llama_stack.providers.utils.inference.model_registry import (
ModelRegistryHelper,
)
from llama_stack.providers.utils.inference.openai_compat import (
convert_message_to_openai_dict,
get_sampling_options,
process_chat_completion_response,
process_chat_completion_stream_response,
)
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
from llama_stack.providers.utils.inference.prompt_adapter import (
chat_completion_request_to_prompt,
request_has_media,
)
from .config import FireworksImplConfig
logger = get_logger(name=__name__, category="inference::fireworks")
class FireworksInferenceAdapter(OpenAIMixin, ModelRegistryHelper, Inference, NeedsRequestProviderData):
embedding_model_metadata = {
class FireworksInferenceAdapter(OpenAIMixin):
config: FireworksImplConfig
embedding_model_metadata: dict[str, dict[str, int]] = {
"nomic-ai/nomic-embed-text-v1.5": {"embedding_dimension": 768, "context_length": 8192},
"accounts/fireworks/models/qwen3-embedding-8b": {"embedding_dimension": 4096, "context_length": 40960},
}
def __init__(self, config: FireworksImplConfig) -> None:
ModelRegistryHelper.__init__(self)
self.config = config
self.allowed_models = config.allowed_models
async def initialize(self) -> None:
pass
async def shutdown(self) -> None:
pass
provider_data_api_key_field: str = "fireworks_api_key"
def get_api_key(self) -> str:
config_api_key = self.config.api_key.get_secret_value() if self.config.api_key else None
if config_api_key:
return config_api_key
else:
provider_data = self.get_request_provider_data()
if provider_data is None or not provider_data.fireworks_api_key:
raise ValueError(
'Pass Fireworks API Key in the header X-LlamaStack-Provider-Data as { "fireworks_api_key": <your api key>}'
)
return provider_data.fireworks_api_key
return self.config.api_key.get_secret_value() if self.config.api_key else None # type: ignore[return-value]
def get_base_url(self) -> str:
return "https://api.fireworks.ai/inference/v1"
def _get_client(self) -> Fireworks:
fireworks_api_key = self.get_api_key()
return Fireworks(api_key=fireworks_api_key)
def _preprocess_prompt_for_fireworks(self, prompt: str) -> str:
"""Remove BOS token as Fireworks automatically prepends it"""
if prompt.startswith("<|begin_of_text|>"):
return prompt[len("<|begin_of_text|>") :]
return prompt
async def chat_completion(
self,
model_id: str,
messages: list[Message],
sampling_params: SamplingParams | None = None,
tools: list[ToolDefinition] | None = None,
tool_choice: ToolChoice | None = ToolChoice.auto,
tool_prompt_format: ToolPromptFormat | None = None,
response_format: ResponseFormat | None = None,
stream: bool | None = False,
logprobs: LogProbConfig | None = None,
tool_config: ToolConfig | None = None,
) -> AsyncGenerator:
if sampling_params is None:
sampling_params = SamplingParams()
model = await self.model_store.get_model(model_id)
request = ChatCompletionRequest(
model=model.provider_resource_id,
messages=messages,
sampling_params=sampling_params,
tools=tools or [],
response_format=response_format,
stream=stream,
logprobs=logprobs,
tool_config=tool_config,
)
if stream:
return self._stream_chat_completion(request)
else:
return await self._nonstream_chat_completion(request)
async def _nonstream_chat_completion(self, request: ChatCompletionRequest) -> ChatCompletionResponse:
params = await self._get_params(request)
if "messages" in params:
r = await self._get_client().chat.completions.acreate(**params)
else:
r = await self._get_client().completion.acreate(**params)
return process_chat_completion_response(r, request)
async def _stream_chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
params = await self._get_params(request)
async def _to_async_generator():
if "messages" in params:
stream = self._get_client().chat.completions.acreate(**params)
else:
stream = self._get_client().completion.acreate(**params)
async for chunk in stream:
yield chunk
stream = _to_async_generator()
async for chunk in process_chat_completion_stream_response(stream, request):
yield chunk
def _build_options(
self,
sampling_params: SamplingParams | None,
fmt: ResponseFormat | None,
logprobs: LogProbConfig | None,
) -> dict:
options = get_sampling_options(sampling_params)
options.setdefault("max_tokens", 512)
if fmt:
if fmt.type == ResponseFormatType.json_schema.value:
options["response_format"] = {
"type": "json_object",
"schema": fmt.json_schema,
}
elif fmt.type == ResponseFormatType.grammar.value:
options["response_format"] = {
"type": "grammar",
"grammar": fmt.bnf,
}
else:
raise ValueError(f"Unknown response format {fmt.type}")
if logprobs and logprobs.top_k:
options["logprobs"] = logprobs.top_k
if options["logprobs"] <= 0 or options["logprobs"] >= 5:
raise ValueError("Required range: 0 < top_k < 5")
return options
async def _get_params(self, request: ChatCompletionRequest) -> dict:
input_dict = {}
media_present = request_has_media(request)
llama_model = self.get_llama_model(request.model)
# TODO: tools are never added to the request, so we need to add them here
if media_present or not llama_model:
input_dict["messages"] = [await convert_message_to_openai_dict(m, download=True) for m in request.messages]
else:
input_dict["prompt"] = await chat_completion_request_to_prompt(request, llama_model)
# Fireworks always prepends with BOS
if "prompt" in input_dict:
if input_dict["prompt"].startswith("<|begin_of_text|>"):
input_dict["prompt"] = input_dict["prompt"][len("<|begin_of_text|>") :]
params = {
"model": request.model,
**input_dict,
"stream": bool(request.stream),
**self._build_options(request.sampling_params, request.response_format, request.logprobs),
}
logger.debug(f"params to fireworks: {params}")
return params

View file

@ -10,6 +10,6 @@ from .config import GeminiConfig
async def get_adapter_impl(config: GeminiConfig, _deps):
from .gemini import GeminiInferenceAdapter
impl = GeminiInferenceAdapter(config)
impl = GeminiInferenceAdapter(config=config)
await impl.initialize()
return impl

View file

@ -8,6 +8,7 @@ from typing import Any
from pydantic import BaseModel, Field
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
from llama_stack.schema_utils import json_schema_type
@ -19,7 +20,7 @@ class GeminiProviderDataValidator(BaseModel):
@json_schema_type
class GeminiConfig(BaseModel):
class GeminiConfig(RemoteInferenceProviderConfig):
api_key: str | None = Field(
default=None,
description="API key for Gemini models",

View file

@ -4,33 +4,21 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
from .config import GeminiConfig
class GeminiInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin):
embedding_model_metadata = {
class GeminiInferenceAdapter(OpenAIMixin):
config: GeminiConfig
provider_data_api_key_field: str = "gemini_api_key"
embedding_model_metadata: dict[str, dict[str, int]] = {
"text-embedding-004": {"embedding_dimension": 768, "context_length": 2048},
}
def __init__(self, config: GeminiConfig) -> None:
LiteLLMOpenAIMixin.__init__(
self,
litellm_provider_name="gemini",
api_key_from_config=config.api_key,
provider_data_api_key_field="gemini_api_key",
)
self.config = config
get_api_key = LiteLLMOpenAIMixin.get_api_key
def get_api_key(self) -> str:
return self.config.api_key or ""
def get_base_url(self):
return "https://generativelanguage.googleapis.com/v1beta/openai/"
async def initialize(self) -> None:
await super().initialize()
async def shutdown(self) -> None:
await super().shutdown()

View file

@ -11,5 +11,5 @@ async def get_adapter_impl(config: GroqConfig, _deps):
# import dynamically so the import is used only when it is needed
from .groq import GroqInferenceAdapter
adapter = GroqInferenceAdapter(config)
adapter = GroqInferenceAdapter(config=config)
return adapter

Some files were not shown because too many files have changed in this diff Show more