Merge branch 'meta-llama:main' into main

This commit is contained in:
Vaishnavi Hire 2025-04-23 11:37:53 -04:00 committed by GitHub
commit 11810c9a03
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
172 changed files with 21921 additions and 10564 deletions

View file

@ -34,22 +34,20 @@ jobs:
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
- name: Install uv
uses: astral-sh/setup-uv@22695119d769bdb6f7032ad67b9bca0ef8c4a174 # v5.4.0
uses: astral-sh/setup-uv@0c5e2b8115b80b4c7c5ddf6ffdd634974642d182 # v5.4.1
with:
python-version: "3.10"
- name: Install Ollama
- name: Install and start Ollama
run: |
# the ollama installer also starts the ollama service
curl -fsSL https://ollama.com/install.sh | sh
- name: Pull Ollama image
run: |
# TODO: cache the model. OLLAMA_MODELS defaults to ~ollama/.ollama/models.
ollama pull llama3.2:3b-instruct-fp16
- name: Start Ollama in background
run: |
nohup ollama run llama3.2:3b-instruct-fp16 > ollama.log 2>&1 &
- name: Set Up Environment and Install Dependencies
run: |
uv sync --extra dev --extra test
@ -61,21 +59,6 @@ jobs:
uv pip install -e .
llama stack build --template ollama --image-type venv
- name: Wait for Ollama to start
run: |
echo "Waiting for Ollama..."
for i in {1..30}; do
if curl -s http://localhost:11434 | grep -q "Ollama is running"; then
echo "Ollama is running!"
exit 0
fi
sleep 1
done
echo "Ollama failed to start"
ollama ps
ollama.log
exit 1
- name: Start Llama Stack server in background
if: matrix.client-type == 'http'
env:
@ -99,6 +82,17 @@ jobs:
cat server.log
exit 1
- name: Verify Ollama status is OK
if: matrix.client-type == 'http'
run: |
echo "Verifying Ollama status..."
ollama_status=$(curl -s -L http://127.0.0.1:8321/v1/providers/ollama|jq --raw-output .health.status)
echo "Ollama status: $ollama_status"
if [ "$ollama_status" != "OK" ]; then
echo "Ollama health check failed"
exit 1
fi
- name: Run Integration Tests
env:
INFERENCE_MODEL: "meta-llama/Llama-3.2-3B-Instruct"

View file

@ -31,3 +31,12 @@ jobs:
- name: Verify if there are any diff files after pre-commit
run: |
git diff --exit-code || (echo "There are uncommitted changes, run pre-commit locally and commit again" && exit 1)
- name: Verify if there are any new files after pre-commit
run: |
unstaged_files=$(git ls-files --others --exclude-standard)
if [ -n "$unstaged_files" ]; then
echo "There are uncommitted new files, run pre-commit locally and commit again"
echo "$unstaged_files"
exit 1
fi

View file

@ -56,7 +56,7 @@ jobs:
python-version: '3.10'
- name: Install uv
uses: astral-sh/setup-uv@22695119d769bdb6f7032ad67b9bca0ef8c4a174 # v5.4.0
uses: astral-sh/setup-uv@0c5e2b8115b80b4c7c5ddf6ffdd634974642d182 # v5.4.1
with:
python-version: "3.10"
@ -81,3 +81,29 @@ jobs:
run: |
source test/bin/activate
uv pip list
build-single-provider:
runs-on: ubuntu-latest
steps:
- name: Checkout repository
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
- name: Set up Python
uses: actions/setup-python@8d9ed9ac5c53483de85588cdf95a591a75ab9f55 # v5.5.0
with:
python-version: '3.10'
- name: Install uv
uses: astral-sh/setup-uv@0c5e2b8115b80b4c7c5ddf6ffdd634974642d182 # v5.4.1
with:
python-version: "3.10"
- name: Install LlamaStack
run: |
uv venv
source .venv/bin/activate
uv pip install -e .
- name: Build a single provider
run: |
USE_COPY_NOT_MOUNT=true LLAMA_STACK_DIR=. uv run llama stack build --image-type venv --image-name test --providers inference=remote::ollama

View file

@ -9,6 +9,11 @@ on:
jobs:
test-external-providers:
runs-on: ubuntu-latest
strategy:
matrix:
image-type: [venv]
# We don't do container yet, it's tricky to install a package from the host into the
# container and point 'uv pip install' to the correct path...
steps:
- name: Checkout repository
uses: actions/checkout@v4
@ -35,17 +40,25 @@ jobs:
uv sync --extra dev --extra test
uv pip install -e .
- name: Install Ollama custom provider
- name: Apply image type to config file
run: |
yq -i '.image_type = "${{ matrix.image-type }}"' tests/external-provider/llama-stack-provider-ollama/custom-distro.yaml
cat tests/external-provider/llama-stack-provider-ollama/custom-distro.yaml
- name: Setup directory for Ollama custom provider
run: |
mkdir -p tests/external-provider/llama-stack-provider-ollama/src/
cp -a llama_stack/providers/remote/inference/ollama/ tests/external-provider/llama-stack-provider-ollama/src/llama_stack_provider_ollama
uv pip install tests/external-provider/llama-stack-provider-ollama
- name: Create provider configuration
run: |
mkdir -p /tmp/providers.d/remote/inference
cp tests/external-provider/llama-stack-provider-ollama/custom_ollama.yaml /tmp/providers.d/remote/inference/custom_ollama.yaml
- name: Build distro from config file
run: |
USE_COPY_NOT_MOUNT=true LLAMA_STACK_DIR=. uv run llama stack build --config tests/external-provider/llama-stack-provider-ollama/custom-distro.yaml
- name: Wait for Ollama to start
run: |
echo "Waiting for Ollama..."
@ -62,11 +75,13 @@ jobs:
exit 1
- name: Start Llama Stack server in background
if: ${{ matrix.image-type }} == 'venv'
env:
INFERENCE_MODEL: "meta-llama/Llama-3.2-3B-Instruct"
run: |
source .venv/bin/activate
nohup uv run llama stack run tests/external-provider/llama-stack-provider-ollama/run.yaml --image-type venv > server.log 2>&1 &
source ci-test/bin/activate
uv run pip list
nohup uv run --active llama stack run tests/external-provider/llama-stack-provider-ollama/run.yaml --image-type ${{ matrix.image-type }} > server.log 2>&1 &
- name: Wait for Llama Stack server to be ready
run: |

View file

@ -38,7 +38,7 @@ jobs:
with:
python-version: ${{ matrix.python }}
- uses: astral-sh/setup-uv@22695119d769bdb6f7032ad67b9bca0ef8c4a174 # v5.4.0
- uses: astral-sh/setup-uv@0c5e2b8115b80b4c7c5ddf6ffdd634974642d182 # v5.4.1
with:
python-version: ${{ matrix.python }}
enable-cache: false

View file

@ -41,7 +41,7 @@ jobs:
python-version: '3.11'
- name: Install the latest version of uv
uses: astral-sh/setup-uv@22695119d769bdb6f7032ad67b9bca0ef8c4a174 # v5.4.0
uses: astral-sh/setup-uv@0c5e2b8115b80b4c7c5ddf6ffdd634974642d182 # v5.4.1
- name: Sync with uv
run: uv sync --extra docs

View file

@ -9,15 +9,16 @@
[**Quick Start**](https://llama-stack.readthedocs.io/en/latest/getting_started/index.html) | [**Documentation**](https://llama-stack.readthedocs.io/en/latest/index.html) | [**Colab Notebook**](./docs/getting_started.ipynb)
### ✨🎉 Llama 4 Support 🎉✨
We released [Version 0.2.0](https://github.com/meta-llama/llama-stack/releases/tag/v0.2.0) with support for the Llama 4 herd of models released by Meta.
You can now run Llama 4 models on Llama Stack.
<details>
<summary>👋 Click here to see how to run Llama 4 models on Llama Stack </summary>
\
*Note you need 8xH100 GPU-host to run these models*
```bash
pip install -U llama_stack
@ -67,6 +68,9 @@ print(f"Assistant> {response.completion_message.content}")
As more providers start supporting Llama 4, you can use them in Llama Stack as well. We are adding to the list. Stay tuned!
</details>
### Overview
Llama Stack standardizes the core building blocks that simplify AI application development. It codifies best practices across the Llama ecosystem. More specifically, it provides

View file

@ -16,3 +16,14 @@
.hide-title h1 {
display: none;
}
h2, h3, h4 {
font-weight: normal;
}
html[data-theme="dark"] .rst-content div[class^="highlight"] {
background-color: #0b0b0b;
}
pre {
white-space: pre-wrap !important;
word-break: break-all;
}

32
docs/_static/js/detect_theme.js vendored Normal file
View file

@ -0,0 +1,32 @@
document.addEventListener("DOMContentLoaded", function () {
const prefersDark = window.matchMedia("(prefers-color-scheme: dark)").matches;
const htmlElement = document.documentElement;
// Check if theme is saved in localStorage
const savedTheme = localStorage.getItem("sphinx-rtd-theme");
if (savedTheme) {
// Use the saved theme preference
htmlElement.setAttribute("data-theme", savedTheme);
document.body.classList.toggle("dark", savedTheme === "dark");
} else {
// Fall back to system preference
const theme = prefersDark ? "dark" : "light";
htmlElement.setAttribute("data-theme", theme);
document.body.classList.toggle("dark", theme === "dark");
// Save initial preference
localStorage.setItem("sphinx-rtd-theme", theme);
}
// Listen for theme changes from the existing toggle
const observer = new MutationObserver(function(mutations) {
mutations.forEach(function(mutation) {
if (mutation.attributeName === "data-theme") {
const currentTheme = htmlElement.getAttribute("data-theme");
localStorage.setItem("sphinx-rtd-theme", currentTheme);
}
});
});
observer.observe(htmlElement, { attributes: true });
});

File diff suppressed because it is too large Load diff

File diff suppressed because it is too large Load diff

View file

@ -68,7 +68,8 @@ chunks_response = client.vector_io.query(
### Using the RAG Tool
A better way to ingest documents is to use the RAG Tool. This tool allows you to ingest documents from URLs, files, etc.
and automatically chunks them into smaller pieces.
and automatically chunks them into smaller pieces. More examples for how to format a RAGDocument can be found in the
[appendix](#more-ragdocument-examples).
```python
from llama_stack_client import RAGDocument
@ -178,3 +179,38 @@ for vector_db_id in client.vector_dbs.list():
print(f"Unregistering vector database: {vector_db_id.identifier}")
client.vector_dbs.unregister(vector_db_id=vector_db_id.identifier)
```
### Appendix
#### More RAGDocument Examples
```python
from llama_stack_client import RAGDocument
import base64
RAGDocument(document_id="num-0", content={"uri": "file://path/to/file"})
RAGDocument(document_id="num-1", content="plain text")
RAGDocument(
document_id="num-2",
content={
"type": "text",
"text": "plain text input",
}, # for inputs that should be treated as text explicitly
)
RAGDocument(
document_id="num-3",
content={
"type": "image",
"image": {"url": {"uri": "https://mywebsite.com/image.jpg"}},
},
)
B64_ENCODED_IMAGE = base64.b64encode(
requests.get(
"https://raw.githubusercontent.com/meta-llama/llama-stack/refs/heads/main/docs/_static/llama-stack.png"
).content
)
RAGDocuemnt(
document_id="num-4",
content={"type": "image", "image": {"data": B64_ENCODED_IMAGE}},
)
```
for more strongly typed interaction use the typed dicts found [here](https://github.com/meta-llama/llama-stack-client-python/blob/38cd91c9e396f2be0bec1ee96a19771582ba6f17/src/llama_stack_client/types/shared_params/document.py).

View file

@ -41,7 +41,7 @@ client.toolgroups.register(
The tool requires an API key which can be provided either in the configuration or through the request header `X-LlamaStack-Provider-Data`. The format of the header is `{"<provider_name>_api_key": <your api key>}`.
> **NOTE:** When using Tavily Search and Bing Search, the inference output will still display "Brave Search." This is because Llama models have been trained with Brave Search as a built-in tool. Tavily and bing is just being used in lieu of Brave search.
#### Code Interpreter
@ -214,3 +214,69 @@ response = agent.create_turn(
session_id=session_id,
)
```
## Simple Example 2: Using an Agent with the Web Search Tool
1. Start by registering a Tavily API key at [Tavily](https://tavily.com/).
2. [Optional] Provide the API key directly to the Llama Stack server
```bash
export TAVILY_SEARCH_API_KEY="your key"
```
```bash
--env TAVILY_SEARCH_API_KEY=${TAVILY_SEARCH_API_KEY}
```
3. Run the following script.
```python
from llama_stack_client.lib.agents.agent import Agent
from llama_stack_client.types.agent_create_params import AgentConfig
from llama_stack_client.lib.agents.event_logger import EventLogger
from llama_stack_client import LlamaStackClient
client = LlamaStackClient(
base_url=f"http://localhost:8321",
provider_data={
"tavily_search_api_key": "your_TAVILY_SEARCH_API_KEY"
}, # Set this from the client side. No need to provide it if it has already been configured on the Llama Stack server.
)
agent = Agent(
client,
model="meta-llama/Llama-3.2-3B-Instruct",
instructions=(
"You are a web search assistant, must use websearch tool to look up the most current and precise information available. "
),
tools=["builtin::websearch"],
)
session_id = agent.create_session("websearch-session")
response = agent.create_turn(
messages=[
{"role": "user", "content": "How did the USA perform in the last Olympics?"}
],
session_id=session_id,
)
for log in EventLogger().log(response):
log.print()
```
## Simple Example3: Using an Agent with the WolframAlpha Tool
1. Start by registering for a WolframAlpha API key at [WolframAlpha Developer Portal](https://developer.wolframalpha.com/access).
2. Provide the API key either when starting the Llama Stack server:
```bash
--env WOLFRAM_ALPHA_API_KEY=${WOLFRAM_ALPHA_API_KEY}
```
or from the client side:
```python
client = LlamaStackClient(
base_url="http://localhost:8321",
provider_data={"wolfram_alpha_api_key": wolfram_api_key},
)
```
3. Configure the tools in the Agent by setting `tools=["builtin::wolfram_alpha"]`.
4. Example user query:
```python
response = agent.create_turn(
messages=[{"role": "user", "content": "Solve x^2 + 2x + 1 = 0 using WolframAlpha"}],
session_id=session_id,
)
```
```

View file

@ -112,6 +112,8 @@ html_theme_options = {
# "style_nav_header_background": "#c3c9d4",
}
default_dark_mode = False
html_static_path = ["../_static"]
# html_logo = "../_static/llama-stack-logo.png"
# html_style = "../_static/css/my_theme.css"
@ -119,6 +121,7 @@ html_static_path = ["../_static"]
def setup(app):
app.add_css_file("css/my_theme.css")
app.add_js_file("js/detect_theme.js")
def dockerhub_role(name, rawtext, text, lineno, inliner, options={}, content=[]):
url = f"https://hub.docker.com/r/llamastack/{text}"

View file

@ -176,7 +176,11 @@ distribution_spec:
safety: inline::llama-guard
agents: inline::meta-reference
telemetry: inline::meta-reference
image_name: ollama
image_type: conda
# If some providers are external, you can specify the path to the implementation
external_providers_dir: /etc/llama-stack/providers.d
```
```
@ -184,6 +188,57 @@ llama stack build --config llama_stack/templates/ollama/build.yaml
```
:::
:::{tab-item} Building with External Providers
Llama Stack supports external providers that live outside of the main codebase. This allows you to create and maintain your own providers independently or use community-provided providers.
To build a distribution with external providers, you need to:
1. Configure the `external_providers_dir` in your build configuration file:
```yaml
# Example my-external-stack.yaml with external providers
version: '2'
distribution_spec:
description: Custom distro for CI tests
providers:
inference:
- remote::custom_ollama
# Add more providers as needed
image_type: container
image_name: ci-test
# Path to external provider implementations
external_providers_dir: /etc/llama-stack/providers.d
```
Here's an example for a custom Ollama provider:
```yaml
adapter:
adapter_type: custom_ollama
pip_packages:
- ollama
- aiohttp
- llama-stack-provider-ollama # This is the provider package
config_class: llama_stack_ollama_provider.config.OllamaImplConfig
module: llama_stack_ollama_provider
api_dependencies: []
optional_api_dependencies: []
```
The `pip_packages` section lists the Python packages required by the provider, as well as the
provider package itself. The package must be available on PyPI or can be provided from a local
directory or a git repository (git must be installed on the build environment).
2. Build your distribution using the config file:
```
llama stack build --config my-external-stack.yaml
```
For more information on external providers, including directory structure, provider types, and implementation requirements, see the [External Providers documentation](../providers/external.md).
:::
:::{tab-item} Building Container
```{admonition} Podman Alternative
@ -231,7 +286,7 @@ options:
-h, --help show this help message and exit
--port PORT Port to run the server on. It can also be passed via the env var LLAMA_STACK_PORT. (default: 8321)
--image-name IMAGE_NAME
Name of the image to run. Defaults to the current conda environment (default: None)
Name of the image to run. Defaults to the current environment (default: None)
--disable-ipv6 Disable IPv6 support (default: False)
--env KEY=VALUE Environment variables to pass to the server in KEY=VALUE format. Can be specified multiple times. (default: [])
--tls-keyfile TLS_KEYFILE

View file

@ -2,7 +2,7 @@
The Llama Stack runtime configuration is specified as a YAML file. Here is a simplified version of an example configuration file for the Ollama distribution:
```{dropdown} Sample Configuration File
```{dropdown} 👋 Click here for a Sample Configuration File
```yaml
version: 2

View file

@ -7,13 +7,18 @@ In this guide, we'll use a local [Kind](https://kind.sigs.k8s.io/) cluster and a
First, create a local Kubernetes cluster via Kind:
```bash
```
kind create cluster --image kindest/node:v1.32.0 --name llama-stack-test
```
First, create a Kubernetes PVC and Secret for downloading and storing Hugging Face model:
First set your hugging face token as an environment variable.
```
export HF_TOKEN=$(echo -n "your-hf-token" | base64)
```
```bash
Now create a Kubernetes PVC and Secret for downloading and storing Hugging Face model:
```
cat <<EOF |kubectl apply -f -
apiVersion: v1
kind: PersistentVolumeClaim
@ -33,13 +38,14 @@ metadata:
name: hf-token-secret
type: Opaque
data:
token: $(HF_TOKEN)
token: $HF_TOKEN
EOF
```
Next, start the vLLM server as a Kubernetes Deployment and Service:
```bash
```
cat <<EOF |kubectl apply -f -
apiVersion: apps/v1
kind: Deployment
@ -95,7 +101,7 @@ EOF
We can verify that the vLLM server has started successfully via the logs (this might take a couple of minutes to download the model):
```bash
```
$ kubectl logs -l app.kubernetes.io/name=vllm
...
INFO: Started server process [1]
@ -119,8 +125,8 @@ providers:
Once we have defined the run configuration for Llama Stack, we can build an image with that configuration and the server source code:
```bash
cat >/tmp/test-vllm-llama-stack/Containerfile.llama-stack-run-k8s <<EOF
```
tmp_dir=$(mktemp -d) && cat >$tmp_dir/Containerfile.llama-stack-run-k8s <<EOF
FROM distribution-myenv:dev
RUN apt-get update && apt-get install -y git
@ -128,14 +134,14 @@ RUN git clone https://github.com/meta-llama/llama-stack.git /app/llama-stack-sou
ADD ./vllm-llama-stack-run-k8s.yaml /app/config.yaml
EOF
podman build -f /tmp/test-vllm-llama-stack/Containerfile.llama-stack-run-k8s -t llama-stack-run-k8s /tmp/test-vllm-llama-stack
podman build -f $tmp_dir/Containerfile.llama-stack-run-k8s -t llama-stack-run-k8s $tmp_dir
```
### Deploying Llama Stack Server in Kubernetes
We can then start the Llama Stack server by deploying a Kubernetes Pod and Service:
```bash
```
cat <<EOF |kubectl apply -f -
apiVersion: v1
kind: PersistentVolumeClaim
@ -195,7 +201,7 @@ EOF
### Verifying the Deployment
We can check that the LlamaStack server has started:
```bash
```
$ kubectl logs -l app.kubernetes.io/name=llama-stack
...
INFO: Started server process [1]
@ -207,7 +213,7 @@ INFO: Uvicorn running on http://['::', '0.0.0.0']:5000 (Press CTRL+C to quit
Finally, we forward the Kubernetes service to a local port and test some inference requests against it via the Llama Stack Client:
```bash
```
kubectl port-forward service/llama-stack-service 5000:5000
llama-stack-client --endpoint http://localhost:5000 inference chat-completion --message "hello, what model are you?"
```

View file

@ -24,7 +24,7 @@ The key files in the app are `ExampleLlamaStackLocalInference.kt`, `ExampleLlama
Add the following dependency in your `build.gradle.kts` file:
```
dependencies {
implementation("com.llama.llamastack:llama-stack-client-kotlin:0.1.4.2")
implementation("com.llama.llamastack:llama-stack-client-kotlin:0.2.2")
}
```
This will download jar files in your gradle cache in a directory like `~/.gradle/caches/modules-2/files-2.1/com.llama.llamastack/`
@ -37,11 +37,7 @@ For local inferencing, it is required to include the ExecuTorch library into you
Include the ExecuTorch library by:
1. Download the `download-prebuilt-et-lib.sh` script file from the [llama-stack-client-kotlin-client-local](https://github.com/meta-llama/llama-stack-client-kotlin/tree/latest-release/llama-stack-client-kotlin-client-local/download-prebuilt-et-lib.sh) directory to your local machine.
2. Move the script to the top level of your Android app where the app directory resides:
<p align="center">
<img src="https://github.com/meta-llama/llama-stack-client-kotlin/blob/latest-release/doc/img/example_android_app_directory.png" style="width:300px">
</p>
2. Move the script to the top level of your Android app where the `app` directory resides.
3. Run `sh download-prebuilt-et-lib.sh` to create an `app/libs` directory and download the `executorch.aar` in that path. This generates an ExecuTorch library for the XNNPACK delegate.
4. Add the `executorch.aar` dependency in your `build.gradle.kts` file:
```
@ -52,6 +48,8 @@ dependencies {
}
```
See other dependencies for the local RAG in Android app [README](https://github.com/meta-llama/llama-stack-client-kotlin/tree/latest-release/examples/android_app#quick-start).
## Llama Stack APIs in Your Android App
Breaking down the demo app, this section will show the core pieces that are used to initialize and run inference with Llama Stack using the Kotlin library.
@ -60,7 +58,7 @@ Start a Llama Stack server on localhost. Here is an example of how you can do th
```
conda create -n stack-fireworks python=3.10
conda activate stack-fireworks
pip install --no-cache llama-stack==0.1.4
pip install --no-cache llama-stack==0.2.2
llama stack build --template fireworks --image-type conda
export FIREWORKS_API_KEY=<SOME_KEY>
llama stack run fireworks --port 5050

View file

@ -1,88 +0,0 @@
<!-- This file was auto-generated by distro_codegen.py, please edit source -->
# NVIDIA Distribution
The `llamastack/distribution-nvidia` distribution consists of the following provider configurations.
| API | Provider(s) |
|-----|-------------|
| agents | `inline::meta-reference` |
| datasetio | `inline::localfs` |
| eval | `inline::meta-reference` |
| inference | `remote::nvidia` |
| post_training | `remote::nvidia` |
| safety | `remote::nvidia` |
| scoring | `inline::basic` |
| telemetry | `inline::meta-reference` |
| tool_runtime | `inline::rag-runtime` |
| vector_io | `inline::faiss` |
### Environment Variables
The following environment variables can be configured:
- `NVIDIA_API_KEY`: NVIDIA API Key (default: ``)
- `NVIDIA_USER_ID`: NVIDIA User ID (default: `llama-stack-user`)
- `NVIDIA_DATASET_NAMESPACE`: NVIDIA Dataset Namespace (default: `default`)
- `NVIDIA_ACCESS_POLICIES`: NVIDIA Access Policies (default: `{}`)
- `NVIDIA_PROJECT_ID`: NVIDIA Project ID (default: `test-project`)
- `NVIDIA_CUSTOMIZER_URL`: NVIDIA Customizer URL (default: `https://customizer.api.nvidia.com`)
- `NVIDIA_OUTPUT_MODEL_DIR`: NVIDIA Output Model Directory (default: `test-example-model@v1`)
- `GUARDRAILS_SERVICE_URL`: URL for the NeMo Guardrails Service (default: `http://0.0.0.0:7331`)
- `INFERENCE_MODEL`: Inference model (default: `Llama3.1-8B-Instruct`)
- `SAFETY_MODEL`: Name of the model to use for safety (default: `meta/llama-3.1-8b-instruct`)
### Models
The following models are available by default:
- `meta/llama3-8b-instruct (aliases: meta-llama/Llama-3-8B-Instruct)`
- `meta/llama3-70b-instruct (aliases: meta-llama/Llama-3-70B-Instruct)`
- `meta/llama-3.1-8b-instruct (aliases: meta-llama/Llama-3.1-8B-Instruct)`
- `meta/llama-3.1-70b-instruct (aliases: meta-llama/Llama-3.1-70B-Instruct)`
- `meta/llama-3.1-405b-instruct (aliases: meta-llama/Llama-3.1-405B-Instruct-FP8)`
- `meta/llama-3.2-1b-instruct (aliases: meta-llama/Llama-3.2-1B-Instruct)`
- `meta/llama-3.2-3b-instruct (aliases: meta-llama/Llama-3.2-3B-Instruct)`
- `meta/llama-3.2-11b-vision-instruct (aliases: meta-llama/Llama-3.2-11B-Vision-Instruct)`
- `meta/llama-3.2-90b-vision-instruct (aliases: meta-llama/Llama-3.2-90B-Vision-Instruct)`
- `nvidia/llama-3.2-nv-embedqa-1b-v2 `
- `nvidia/nv-embedqa-e5-v5 `
- `nvidia/nv-embedqa-mistral-7b-v2 `
- `snowflake/arctic-embed-l `
### Prerequisite: API Keys
Make sure you have access to a NVIDIA API Key. You can get one by visiting [https://build.nvidia.com/](https://build.nvidia.com/).
## Running Llama Stack with NVIDIA
You can do this via Conda (build code) or Docker which has a pre-built image.
### Via Docker
This method allows you to get started quickly without having to build the distribution code.
```bash
LLAMA_STACK_PORT=8321
docker run \
-it \
--pull always \
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
-v ./run.yaml:/root/my-run.yaml \
llamastack/distribution-nvidia \
--yaml-config /root/my-run.yaml \
--port $LLAMA_STACK_PORT \
--env NVIDIA_API_KEY=$NVIDIA_API_KEY
```
### Via Conda
```bash
llama stack build --template nvidia --image-type conda
llama stack run ./run.yaml \
--port 8321 \
--env NVIDIA_API_KEY=$NVIDIA_API_KEY
--env INFERENCE_MODEL=$INFERENCE_MODEL
```

View file

@ -43,7 +43,9 @@ The following models are available by default:
- `groq/llama-3.3-70b-versatile (aliases: meta-llama/Llama-3.3-70B-Instruct)`
- `groq/llama-3.2-3b-preview (aliases: meta-llama/Llama-3.2-3B-Instruct)`
- `groq/llama-4-scout-17b-16e-instruct (aliases: meta-llama/Llama-4-Scout-17B-16E-Instruct)`
- `groq/meta-llama/llama-4-scout-17b-16e-instruct (aliases: meta-llama/Llama-4-Scout-17B-16E-Instruct)`
- `groq/llama-4-maverick-17b-128e-instruct (aliases: meta-llama/Llama-4-Maverick-17B-128E-Instruct)`
- `groq/meta-llama/llama-4-maverick-17b-128e-instruct (aliases: meta-llama/Llama-4-Maverick-17B-128E-Instruct)`
### Prerequisite: API Keys

View file

@ -1,3 +1,4 @@
<!-- This file was auto-generated by distro_codegen.py, please edit source -->
# NVIDIA Distribution
The `llamastack/distribution-nvidia` distribution consists of the following provider configurations.
@ -5,34 +6,130 @@ The `llamastack/distribution-nvidia` distribution consists of the following prov
| API | Provider(s) |
|-----|-------------|
| agents | `inline::meta-reference` |
| datasetio | `inline::localfs` |
| eval | `inline::meta-reference` |
| inference | `remote::nvidia` |
| memory | `inline::faiss`, `remote::chromadb`, `remote::pgvector` |
| safety | `inline::llama-guard` |
| post_training | `remote::nvidia` |
| safety | `remote::nvidia` |
| scoring | `inline::basic` |
| telemetry | `inline::meta-reference` |
| tool_runtime | `inline::rag-runtime` |
| vector_io | `inline::faiss` |
### Environment Variables
The following environment variables can be configured:
- `LLAMASTACK_PORT`: Port for the Llama Stack distribution server (default: `8321`)
- `NVIDIA_API_KEY`: NVIDIA API Key (default: ``)
- `NVIDIA_USER_ID`: NVIDIA User ID (default: `llama-stack-user`)
- `NVIDIA_DATASET_NAMESPACE`: NVIDIA Dataset Namespace (default: `default`)
- `NVIDIA_ACCESS_POLICIES`: NVIDIA Access Policies (default: `{}`)
- `NVIDIA_PROJECT_ID`: NVIDIA Project ID (default: `test-project`)
- `NVIDIA_CUSTOMIZER_URL`: NVIDIA Customizer URL (default: `https://customizer.api.nvidia.com`)
- `NVIDIA_OUTPUT_MODEL_DIR`: NVIDIA Output Model Directory (default: `test-example-model@v1`)
- `GUARDRAILS_SERVICE_URL`: URL for the NeMo Guardrails Service (default: `http://0.0.0.0:7331`)
- `INFERENCE_MODEL`: Inference model (default: `Llama3.1-8B-Instruct`)
- `SAFETY_MODEL`: Name of the model to use for safety (default: `meta/llama-3.1-8b-instruct`)
### Models
The following models are available by default:
- `${env.INFERENCE_MODEL} (None)`
- `meta/llama3-8b-instruct (aliases: meta-llama/Llama-3-8B-Instruct)`
- `meta/llama3-70b-instruct (aliases: meta-llama/Llama-3-70B-Instruct)`
- `meta/llama-3.1-8b-instruct (aliases: meta-llama/Llama-3.1-8B-Instruct)`
- `meta/llama-3.1-70b-instruct (aliases: meta-llama/Llama-3.1-70B-Instruct)`
- `meta/llama-3.1-405b-instruct (aliases: meta-llama/Llama-3.1-405B-Instruct-FP8)`
- `meta/llama-3.2-1b-instruct (aliases: meta-llama/Llama-3.2-1B-Instruct)`
- `meta/llama-3.2-3b-instruct (aliases: meta-llama/Llama-3.2-3B-Instruct)`
- `meta/llama-3.2-11b-vision-instruct (aliases: meta-llama/Llama-3.2-11B-Vision-Instruct)`
- `meta/llama-3.2-90b-vision-instruct (aliases: meta-llama/Llama-3.2-90B-Vision-Instruct)`
- `meta/llama-3.3-70b-instruct (aliases: meta-llama/Llama-3.3-70B-Instruct)`
- `nvidia/llama-3.2-nv-embedqa-1b-v2 `
- `nvidia/nv-embedqa-e5-v5 `
- `nvidia/nv-embedqa-mistral-7b-v2 `
- `snowflake/arctic-embed-l `
### Prerequisite: API Keys
## Prerequisites
### NVIDIA API Keys
Make sure you have access to a NVIDIA API Key. You can get one by visiting [https://build.nvidia.com/](https://build.nvidia.com/).
Make sure you have access to a NVIDIA API Key. You can get one by visiting [https://build.nvidia.com/](https://build.nvidia.com/). Use this key for the `NVIDIA_API_KEY` environment variable.
### Deploy NeMo Microservices Platform
The NVIDIA NeMo microservices platform supports end-to-end microservice deployment of a complete AI flywheel on your Kubernetes cluster through the NeMo Microservices Helm Chart. Please reference the [NVIDIA NeMo Microservices documentation](https://docs.nvidia.com/nemo/microservices/latest/about/index.html) for platform prerequisites and instructions to install and deploy the platform.
## Supported Services
Each Llama Stack API corresponds to a specific NeMo microservice. The core microservices (Customizer, Evaluator, Guardrails) are exposed by the same endpoint. The platform components (Data Store) are each exposed by separate endpoints.
### Inference: NVIDIA NIM
NVIDIA NIM is used for running inference with registered models. There are two ways to access NVIDIA NIMs:
1. Hosted (default): Preview APIs hosted at https://integrate.api.nvidia.com (Requires an API key)
2. Self-hosted: NVIDIA NIMs that run on your own infrastructure.
The deployed platform includes the NIM Proxy microservice, which is the service that provides to access your NIMs (for example, to run inference on a model). Set the `NVIDIA_BASE_URL` environment variable to use your NVIDIA NIM Proxy deployment.
### Datasetio API: NeMo Data Store
The NeMo Data Store microservice serves as the default file storage solution for the NeMo microservices platform. It exposts APIs compatible with the Hugging Face Hub client (`HfApi`), so you can use the client to interact with Data Store. The `NVIDIA_DATASETS_URL` environment variable should point to your NeMo Data Store endpoint.
See the [NVIDIA Datasetio docs](/llama_stack/providers/remote/datasetio/nvidia/README.md) for supported features and example usage.
### Eval API: NeMo Evaluator
The NeMo Evaluator microservice supports evaluation of LLMs. Launching an Evaluation job with NeMo Evaluator requires an Evaluation Config (an object that contains metadata needed by the job). A Llama Stack Benchmark maps to an Evaluation Config, so registering a Benchmark creates an Evaluation Config in NeMo Evaluator. The `NVIDIA_EVALUATOR_URL` environment variable should point to your NeMo Microservices endpoint.
See the [NVIDIA Eval docs](/llama_stack/providers/remote/eval/nvidia/README.md) for supported features and example usage.
### Post-Training API: NeMo Customizer
The NeMo Customizer microservice supports fine-tuning models. You can reference [this list of supported models](/llama_stack/providers/remote/post_training/nvidia/models.py) that can be fine-tuned using Llama Stack. The `NVIDIA_CUSTOMIZER_URL` environment variable should point to your NeMo Microservices endpoint.
See the [NVIDIA Post-Training docs](/llama_stack/providers/remote/post_training/nvidia/README.md) for supported features and example usage.
### Safety API: NeMo Guardrails
The NeMo Guardrails microservice sits between your application and the LLM, and adds checks and content moderation to a model. The `GUARDRAILS_SERVICE_URL` environment variable should point to your NeMo Microservices endpoint.
See the NVIDIA Safety docs for supported features and example usage.
## Deploying models
In order to use a registered model with the Llama Stack APIs, ensure the corresponding NIM is deployed to your environment. For example, you can use the NIM Proxy microservice to deploy `meta/llama-3.2-1b-instruct`.
Note: For improved inference speeds, we need to use NIM with `fast_outlines` guided decoding system (specified in the request body). This is the default if you deployed the platform with the NeMo Microservices Helm Chart.
```sh
# URL to NeMo NIM Proxy service
export NEMO_URL="http://nemo.test"
curl --location "$NEMO_URL/v1/deployment/model-deployments" \
-H 'accept: application/json' \
-H 'Content-Type: application/json' \
-d '{
"name": "llama-3.2-1b-instruct",
"namespace": "meta",
"config": {
"model": "meta/llama-3.2-1b-instruct",
"nim_deployment": {
"image_name": "nvcr.io/nim/meta/llama-3.2-1b-instruct",
"image_tag": "1.8.3",
"pvc_size": "25Gi",
"gpu": 1,
"additional_envs": {
"NIM_GUIDED_DECODING_BACKEND": "fast_outlines"
}
}
}
}'
```
This NIM deployment should take approximately 10 minutes to go live. [See the docs](https://docs.nvidia.com/nemo/microservices/latest/get-started/tutorials/deploy-nims.html) for more information on how to deploy a NIM and verify it's available for inference.
You can also remove a deployed NIM to free up GPU resources, if needed.
```sh
export NEMO_URL="http://nemo.test"
curl -X DELETE "$NEMO_URL/v1/deployment/model-deployments/meta/llama-3.1-8b-instruct"
```
## Running Llama Stack with NVIDIA
You can do this via Conda (build code) or Docker which has a pre-built image.
You can do this via Conda or venv (build code), or Docker which has a pre-built image.
### Via Docker
@ -54,8 +151,23 @@ docker run \
### Via Conda
```bash
INFERENCE_MODEL=meta-llama/Llama-3.1-8b-Instruct
llama stack build --template nvidia --image-type conda
llama stack run ./run.yaml \
--port 8321 \
--env NVIDIA_API_KEY=$NVIDIA_API_KEY
--env NVIDIA_API_KEY=$NVIDIA_API_KEY \
--env INFERENCE_MODEL=$INFERENCE_MODEL
```
### Via venv
If you've set up your local development environment, you can also build the image using your local virtual environment.
```bash
INFERENCE_MODEL=meta-llama/Llama-3.1-8b-Instruct
llama stack build --template nvidia --image-type venv
llama stack run ./run.yaml \
--port 8321 \
--env NVIDIA_API_KEY=$NVIDIA_API_KEY \
--env INFERENCE_MODEL=$INFERENCE_MODEL
```

View file

@ -25,7 +25,7 @@ The `llamastack/distribution-remote-vllm` distribution consists of the following
| vector_io | `inline::faiss`, `remote::chromadb`, `remote::pgvector` |
You can use this distribution if you have GPUs and want to run an independent vLLM server container for running inference.
You can use this distribution if you want to run an independent vLLM server for inference.
### Environment Variables
@ -41,7 +41,10 @@ The following environment variables can be configured:
## Setting up vLLM server
Both AMD and NVIDIA GPUs can serve as accelerators for the vLLM server, which acts as both the LLM inference provider and the safety provider.
In the following sections, we'll use AMD, NVIDIA or Intel GPUs to serve as hardware accelerators for the vLLM
server, which acts as both the LLM inference provider and the safety provider. Note that vLLM also
[supports many other hardware accelerators](https://docs.vllm.ai/en/latest/getting_started/installation.html) and
that we only use GPUs here for demonstration purposes. Note that if you run into issues, you can include the environment variable `--env VLLM_DEBUG_LOG_API_SERVER_RESPONSE=true` (available in vLLM v0.8.3 and above) in the `docker run` command to enable log response from API server for debugging.
### Setting up vLLM server on AMD GPU
@ -159,6 +162,55 @@ docker run \
--port $SAFETY_PORT
```
### Setting up vLLM server on Intel GPU
Refer to [vLLM Documentation for XPU](https://docs.vllm.ai/en/v0.8.2/getting_started/installation/gpu.html?device=xpu) to get a vLLM endpoint. In addition to vLLM side setup which guides towards installing vLLM from sources orself-building vLLM Docker container, Intel provides prebuilt vLLM container to use on systems with Intel GPUs supported by PyTorch XPU backend:
- [intel/vllm](https://hub.docker.com/r/intel/vllm)
Here is a sample script to start a vLLM server locally via Docker using Intel provided container:
```bash
export INFERENCE_PORT=8000
export INFERENCE_MODEL=meta-llama/Llama-3.2-1B-Instruct
export ZE_AFFINITY_MASK=0
docker run \
--pull always \
--device /dev/dri \
-v /dev/dri/by-path:/dev/dri/by-path \
-v ~/.cache/huggingface:/root/.cache/huggingface \
--env "HUGGING_FACE_HUB_TOKEN=$HF_TOKEN" \
--env ZE_AFFINITY_MASK=$ZE_AFFINITY_MASK \
-p $INFERENCE_PORT:$INFERENCE_PORT \
--ipc=host \
intel/vllm:xpu \
--gpu-memory-utilization 0.7 \
--model $INFERENCE_MODEL \
--port $INFERENCE_PORT
```
If you are using Llama Stack Safety / Shield APIs, then you will need to also run another instance of a vLLM with a corresponding safety model like `meta-llama/Llama-Guard-3-1B` using a script like:
```bash
export SAFETY_PORT=8081
export SAFETY_MODEL=meta-llama/Llama-Guard-3-1B
export ZE_AFFINITY_MASK=1
docker run \
--pull always \
--device /dev/dri \
-v /dev/dri/by-path:/dev/dri/by-path \
-v ~/.cache/huggingface:/root/.cache/huggingface \
--env "HUGGING_FACE_HUB_TOKEN=$HF_TOKEN" \
--env ZE_AFFINITY_MASK=$ZE_AFFINITY_MASK \
-p $SAFETY_PORT:$SAFETY_PORT \
--ipc=host \
intel/vllm:xpu \
--gpu-memory-utilization 0.7 \
--model $SAFETY_MODEL \
--port $SAFETY_PORT
```
## Running Llama Stack
Now you are ready to run Llama Stack with vLLM as the inference provider. You can do this via Conda (build code) or Docker which has a pre-built image.

View file

@ -2,22 +2,22 @@
You can run a Llama Stack server in one of the following ways:
**As a Library**:
## As a Library:
This is the simplest way to get started. Using Llama Stack as a library means you do not need to start a server. This is especially useful when you are not running inference locally and relying on an external inference service (eg. fireworks, together, groq, etc.) See [Using Llama Stack as a Library](importing_as_library)
**Container**:
## Container:
Another simple way to start interacting with Llama Stack is to just spin up a container (via Docker or Podman) which is pre-built with all the providers you need. We provide a number of pre-built images so you can start a Llama Stack server instantly. You can also build your own custom container. Which distribution to choose depends on the hardware you have. See [Selection of a Distribution](selection) for more details.
**Conda**:
## Conda:
If you have a custom or an advanced setup or you are developing on Llama Stack you can also build a custom Llama Stack server. Using `llama stack build` and `llama stack run` you can build/run a custom Llama Stack server containing the exact combination of providers you wish. We have also provided various templates to make getting started easier. See [Building a Custom Distribution](building_distro) for more details.
**Kubernetes**:
## Kubernetes:
If you have built a container image and want to deploy it in a Kubernetes cluster instead of starting the Llama Stack server locally. See [Kubernetes Deployment Guide](kubernetes_deployment) for more details.

View file

@ -0,0 +1,541 @@
# Detailed Tutorial
In this guide, we'll walk through how you can use the Llama Stack (server and client SDK) to test a simple agent.
A Llama Stack agent is a simple integrated system that can perform tasks by combining a Llama model for reasoning with
tools (e.g., RAG, web search, code execution, etc.) for taking actions.
In Llama Stack, we provide a server exposing multiple APIs. These APIs are backed by implementations from different providers.
Llama Stack is a stateful service with REST APIs to support seamless transition of AI applications across different environments. The server can be run in a variety of ways, including as a standalone binary, Docker container, or hosted service. You can build and test using a local server first and deploy to a hosted endpoint for production.
In this guide, we'll walk through how to build a RAG agent locally using Llama Stack with [Ollama](https://ollama.com/)
as the inference [provider](../providers/index.md#inference) for a Llama Model.
## Step 1: Installation and Setup
Install Ollama by following the instructions on the [Ollama website](https://ollama.com/download), then
download Llama 3.2 3B model, and then start the Ollama service.
```bash
ollama pull llama3.2:3b
ollama run llama3.2:3b --keepalive 60m
```
Install [uv](https://docs.astral.sh/uv/) to setup your virtual environment
::::{tab-set}
:::{tab-item} macOS and Linux
Use `curl` to download the script and execute it with `sh`:
```console
curl -LsSf https://astral.sh/uv/install.sh | sh
```
:::
:::{tab-item} Windows
Use `irm` to download the script and execute it with `iex`:
```console
powershell -ExecutionPolicy ByPass -c "irm https://astral.sh/uv/install.ps1 | iex"
```
:::
::::
Setup your virtual environment.
```bash
uv venv --python 3.10
source .venv/bin/activate
```
## Step 2: Run Llama Stack
Llama Stack is a server that exposes multiple APIs, you connect with it using the Llama Stack client SDK.
::::{tab-set}
:::{tab-item} Using `venv`
You can use Python to build and run the Llama Stack server, which is useful for testing and development.
Llama Stack uses a [YAML configuration file](../distributions/configuration.md) to specify the stack setup,
which defines the providers and their settings.
Now let's build and run the Llama Stack config for Ollama.
```bash
INFERENCE_MODEL=llama3.2:3b llama stack build --template ollama --image-type venv --run
```
:::
:::{tab-item} Using `conda`
You can use Python to build and run the Llama Stack server, which is useful for testing and development.
Llama Stack uses a [YAML configuration file](../distributions/configuration.md) to specify the stack setup,
which defines the providers and their settings.
Now let's build and run the Llama Stack config for Ollama.
```bash
INFERENCE_MODEL=llama3.2:3b llama stack build --template ollama --image-type conda --image-name llama3-3b-conda --run
```
:::
:::{tab-item} Using a Container
You can use a container image to run the Llama Stack server. We provide several container images for the server
component that works with different inference providers out of the box. For this guide, we will use
`llamastack/distribution-ollama` as the container image. If you'd like to build your own image or customize the
configurations, please check out [this guide](../references/index.md).
First lets setup some environment variables and create a local directory to mount into the containers file system.
```bash
export INFERENCE_MODEL="llama3.2:3b"
export LLAMA_STACK_PORT=8321
mkdir -p ~/.llama
```
Then start the server using the container tool of your choice. For example, if you are running Docker you can use the
following command:
```bash
docker run -it \
--pull always \
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
-v ~/.llama:/root/.llama \
llamastack/distribution-ollama \
--port $LLAMA_STACK_PORT \
--env INFERENCE_MODEL=$INFERENCE_MODEL \
--env OLLAMA_URL=http://host.docker.internal:11434
```
Note to start the container with Podman, you can do the same but replace `docker` at the start of the command with
`podman`. If you are using `podman` older than `4.7.0`, please also replace `host.docker.internal` in the `OLLAMA_URL`
with `host.containers.internal`.
The configuration YAML for the Ollama distribution is available at `distributions/ollama/run.yaml`.
```{tip}
Docker containers run in their own isolated network namespaces on Linux. To allow the container to communicate with services running on the host via `localhost`, you need `--network=host`. This makes the container use the hosts network directly so it can connect to Ollama running on `localhost:11434`.
Linux users having issues running the above command should instead try the following:
```bash
docker run -it \
--pull always \
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
-v ~/.llama:/root/.llama \
--network=host \
llamastack/distribution-ollama \
--port $LLAMA_STACK_PORT \
--env INFERENCE_MODEL=$INFERENCE_MODEL \
--env OLLAMA_URL=http://localhost:11434
```
:::
::::
You will see output like below:
```
INFO: Application startup complete.
INFO: Uvicorn running on http://['::', '0.0.0.0']:8321 (Press CTRL+C to quit)
```
Now you can use the Llama Stack client to run inference and build agents!
You can reuse the server setup or use the [Llama Stack Client](https://github.com/meta-llama/llama-stack-client-python/).
Note that the client package is already included in the `llama-stack` package.
## Step 3: Run Client CLI
Open a new terminal and navigate to the same directory you started the server from. Then set up a new or activate your
existing server virtual environment.
::::{tab-set}
:::{tab-item} Reuse Server `venv`
```bash
# The client is included in the llama-stack package so we just activate the server venv
source .venv/bin/activate
```
:::
:::{tab-item} Install with `venv`
```bash
uv venv client --python 3.10
source client/bin/activate
pip install llama-stack-client
```
:::
:::{tab-item} Install with `conda`
```bash
yes | conda create -n stack-client python=3.10
conda activate stack-client
pip install llama-stack-client
```
:::
::::
Now let's use the `llama-stack-client` [CLI](../references/llama_stack_client_cli_reference.md) to check the
connectivity to the server.
```bash
llama-stack-client configure --endpoint http://localhost:8321 --api-key none
```
You will see the below:
```
Done! You can now use the Llama Stack Client CLI with endpoint http://localhost:8321
```
List the models
```bash
llama-stack-client models list
Available Models
┏━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┓
┃ model_type ┃ identifier ┃ provider_resource_id ┃ metadata ┃ provider_id ┃
┡━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━┩
│ embedding │ all-MiniLM-L6-v2 │ all-minilm:latest │ {'embedding_dimension': 384.0} │ ollama │
├─────────────────┼─────────────────────────────────────┼─────────────────────────────────────┼───────────────────────────────────────────┼─────────────────┤
│ llm │ llama3.2:3b │ llama3.2:3b │ │ ollama │
└─────────────────┴─────────────────────────────────────┴─────────────────────────────────────┴───────────────────────────────────────────┴─────────────────┘
Total models: 2
```
You can test basic Llama inference completion using the CLI.
```bash
llama-stack-client inference chat-completion --message "tell me a joke"
```
Sample output:
```python
ChatCompletionResponse(
completion_message=CompletionMessage(
content="Here's one:\n\nWhat do you call a fake noodle?\n\nAn impasta!",
role="assistant",
stop_reason="end_of_turn",
tool_calls=[],
),
logprobs=None,
metrics=[
Metric(metric="prompt_tokens", value=14.0, unit=None),
Metric(metric="completion_tokens", value=27.0, unit=None),
Metric(metric="total_tokens", value=41.0, unit=None),
],
)
```
## Step 4: Run the Demos
Note that these demos show the [Python Client SDK](../references/python_sdk_reference/index.md).
Other SDKs are also available, please refer to the [Client SDK](../index.md#client-sdks) list for the complete options.
::::{tab-set}
:::{tab-item} Basic Inference
Now you can run inference using the Llama Stack client SDK.
### i. Create the Script
Create a file `inference.py` and add the following code:
```python
from llama_stack_client import LlamaStackClient
client = LlamaStackClient(base_url="http://localhost:8321")
# List available models
models = client.models.list()
# Select the first LLM
llm = next(m for m in models if m.model_type == "llm")
model_id = llm.identifier
print("Model:", model_id)
response = client.inference.chat_completion(
model_id=model_id,
messages=[
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "Write a haiku about coding"},
],
)
print(response.completion_message.content)
```
### ii. Run the Script
Let's run the script using `uv`
```bash
uv run python inference.py
```
Which will output:
```
Model: llama3.2:3b
Here is a haiku about coding:
Lines of code unfold
Logic flows through digital night
Beauty in the bits
```
:::
:::{tab-item} Build a Simple Agent
Next we can move beyond simple inference and build an agent that can perform tasks using the Llama Stack server.
### i. Create the Script
Create a file `agent.py` and add the following code:
```python
from llama_stack_client import LlamaStackClient
from llama_stack_client import Agent, AgentEventLogger
from rich.pretty import pprint
import uuid
client = LlamaStackClient(base_url=f"http://localhost:8321")
models = client.models.list()
llm = next(m for m in models if m.model_type == "llm")
model_id = llm.identifier
agent = Agent(client, model=model_id, instructions="You are a helpful assistant.")
s_id = agent.create_session(session_name=f"s{uuid.uuid4().hex}")
print("Non-streaming ...")
response = agent.create_turn(
messages=[{"role": "user", "content": "Who are you?"}],
session_id=s_id,
stream=False,
)
print("agent>", response.output_message.content)
print("Streaming ...")
stream = agent.create_turn(
messages=[{"role": "user", "content": "Who are you?"}], session_id=s_id, stream=True
)
for event in stream:
pprint(event)
print("Streaming with print helper...")
stream = agent.create_turn(
messages=[{"role": "user", "content": "Who are you?"}], session_id=s_id, stream=True
)
for event in AgentEventLogger().log(stream):
event.print()
```
### ii. Run the Script
Let's run the script using `uv`
```bash
uv run python agent.py
```
```{dropdown} 👋 Click here to see the sample output
Non-streaming ...
agent> I'm an artificial intelligence designed to assist and communicate with users like you. I don't have a personal identity, but I'm here to provide information, answer questions, and help with tasks to the best of my abilities.
I can be used for a wide range of purposes, such as:
* Providing definitions and explanations
* Offering suggestions and ideas
* Helping with language translation
* Assisting with writing and proofreading
* Generating text or responses to questions
* Playing simple games or chatting about topics of interest
I'm constantly learning and improving my abilities, so feel free to ask me anything, and I'll do my best to help!
Streaming ...
AgentTurnResponseStreamChunk(
│ event=TurnResponseEvent(
│ │ payload=AgentTurnResponseStepStartPayload(
│ │ │ event_type='step_start',
│ │ │ step_id='69831607-fa75-424a-949b-e2049e3129d1',
│ │ │ step_type='inference',
│ │ │ metadata={}
│ │ )
│ )
)
AgentTurnResponseStreamChunk(
│ event=TurnResponseEvent(
│ │ payload=AgentTurnResponseStepProgressPayload(
│ │ │ delta=TextDelta(text='As', type='text'),
│ │ │ event_type='step_progress',
│ │ │ step_id='69831607-fa75-424a-949b-e2049e3129d1',
│ │ │ step_type='inference'
│ │ )
│ )
)
AgentTurnResponseStreamChunk(
│ event=TurnResponseEvent(
│ │ payload=AgentTurnResponseStepProgressPayload(
│ │ │ delta=TextDelta(text=' a', type='text'),
│ │ │ event_type='step_progress',
│ │ │ step_id='69831607-fa75-424a-949b-e2049e3129d1',
│ │ │ step_type='inference'
│ │ )
│ )
)
...
AgentTurnResponseStreamChunk(
│ event=TurnResponseEvent(
│ │ payload=AgentTurnResponseStepCompletePayload(
│ │ │ event_type='step_complete',
│ │ │ step_details=InferenceStep(
│ │ │ │ api_model_response=CompletionMessage(
│ │ │ │ │ content='As a conversational AI, I don\'t have a personal identity in the classical sense. I exist as a program running on computer servers, designed to process and respond to text-based inputs.\n\nI\'m an instance of a type of artificial intelligence called a "language model," which is trained on vast amounts of text data to generate human-like responses. My primary function is to understand and respond to natural language inputs, like our conversation right now.\n\nThink of me as a virtual assistant, a chatbot, or a conversational interface I\'m here to provide information, answer questions, and engage in conversation to the best of my abilities. I don\'t have feelings, emotions, or consciousness like humans do, but I\'m designed to simulate human-like interactions to make our conversations feel more natural and helpful.\n\nSo, that\'s me in a nutshell! What can I help you with today?',
│ │ │ │ │ role='assistant',
│ │ │ │ │ stop_reason='end_of_turn',
│ │ │ │ │ tool_calls=[]
│ │ │ │ ),
│ │ │ │ step_id='69831607-fa75-424a-949b-e2049e3129d1',
│ │ │ │ step_type='inference',
│ │ │ │ turn_id='8b360202-f7cb-4786-baa9-166a1b46e2ca',
│ │ │ │ completed_at=datetime.datetime(2025, 4, 3, 1, 15, 21, 716174, tzinfo=TzInfo(UTC)),
│ │ │ │ started_at=datetime.datetime(2025, 4, 3, 1, 15, 14, 28823, tzinfo=TzInfo(UTC))
│ │ │ ),
│ │ │ step_id='69831607-fa75-424a-949b-e2049e3129d1',
│ │ │ step_type='inference'
│ │ )
│ )
)
AgentTurnResponseStreamChunk(
│ event=TurnResponseEvent(
│ │ payload=AgentTurnResponseTurnCompletePayload(
│ │ │ event_type='turn_complete',
│ │ │ turn=Turn(
│ │ │ │ input_messages=[UserMessage(content='Who are you?', role='user', context=None)],
│ │ │ │ output_message=CompletionMessage(
│ │ │ │ │ content='As a conversational AI, I don\'t have a personal identity in the classical sense. I exist as a program running on computer servers, designed to process and respond to text-based inputs.\n\nI\'m an instance of a type of artificial intelligence called a "language model," which is trained on vast amounts of text data to generate human-like responses. My primary function is to understand and respond to natural language inputs, like our conversation right now.\n\nThink of me as a virtual assistant, a chatbot, or a conversational interface I\'m here to provide information, answer questions, and engage in conversation to the best of my abilities. I don\'t have feelings, emotions, or consciousness like humans do, but I\'m designed to simulate human-like interactions to make our conversations feel more natural and helpful.\n\nSo, that\'s me in a nutshell! What can I help you with today?',
│ │ │ │ │ role='assistant',
│ │ │ │ │ stop_reason='end_of_turn',
│ │ │ │ │ tool_calls=[]
│ │ │ │ ),
│ │ │ │ session_id='abd4afea-4324-43f4-9513-cfe3970d92e8',
│ │ │ │ started_at=datetime.datetime(2025, 4, 3, 1, 15, 14, 28722, tzinfo=TzInfo(UTC)),
│ │ │ │ steps=[
│ │ │ │ │ InferenceStep(
│ │ │ │ │ │ api_model_response=CompletionMessage(
│ │ │ │ │ │ │ content='As a conversational AI, I don\'t have a personal identity in the classical sense. I exist as a program running on computer servers, designed to process and respond to text-based inputs.\n\nI\'m an instance of a type of artificial intelligence called a "language model," which is trained on vast amounts of text data to generate human-like responses. My primary function is to understand and respond to natural language inputs, like our conversation right now.\n\nThink of me as a virtual assistant, a chatbot, or a conversational interface I\'m here to provide information, answer questions, and engage in conversation to the best of my abilities. I don\'t have feelings, emotions, or consciousness like humans do, but I\'m designed to simulate human-like interactions to make our conversations feel more natural and helpful.\n\nSo, that\'s me in a nutshell! What can I help you with today?',
│ │ │ │ │ │ │ role='assistant',
│ │ │ │ │ │ │ stop_reason='end_of_turn',
│ │ │ │ │ │ │ tool_calls=[]
│ │ │ │ │ │ ),
│ │ │ │ │ │ step_id='69831607-fa75-424a-949b-e2049e3129d1',
│ │ │ │ │ │ step_type='inference',
│ │ │ │ │ │ turn_id='8b360202-f7cb-4786-baa9-166a1b46e2ca',
│ │ │ │ │ │ completed_at=datetime.datetime(2025, 4, 3, 1, 15, 21, 716174, tzinfo=TzInfo(UTC)),
│ │ │ │ │ │ started_at=datetime.datetime(2025, 4, 3, 1, 15, 14, 28823, tzinfo=TzInfo(UTC))
│ │ │ │ │ )
│ │ │ │ ],
│ │ │ │ turn_id='8b360202-f7cb-4786-baa9-166a1b46e2ca',
│ │ │ │ completed_at=datetime.datetime(2025, 4, 3, 1, 15, 21, 727364, tzinfo=TzInfo(UTC)),
│ │ │ │ output_attachments=[]
│ │ │ )
│ │ )
│ )
)
Streaming with print helper...
inference> Déjà vu!
As I mentioned earlier, I'm an artificial intelligence language model. I don't have a personal identity or consciousness like humans do. I exist solely to process and respond to text-based inputs, providing information and assistance on a wide range of topics.
I'm a computer program designed to simulate human-like conversations, using natural language processing (NLP) and machine learning algorithms to understand and generate responses. My purpose is to help users like you with their questions, provide information, and engage in conversation.
Think of me as a virtual companion, a helpful tool designed to make your interactions more efficient and enjoyable. I don't have personal opinions, emotions, or biases, but I'm here to provide accurate and informative responses to the best of my abilities.
So, who am I? I'm just a computer program designed to help you!
```
:::
:::{tab-item} Build a RAG Agent
For our last demo, we can build a RAG agent that can answer questions about the Torchtune project using the documents
in a vector database.
### i. Create the Script
Create a file `rag_agent.py` and add the following code:
```python
from llama_stack_client import LlamaStackClient
from llama_stack_client import Agent, AgentEventLogger
from llama_stack_client.types import Document
import uuid
from termcolor import cprint
client = LlamaStackClient(base_url="http://localhost:8321")
# Create a vector database instance
embed_lm = next(m for m in client.models.list() if m.model_type == "embedding")
embedding_model = embed_lm.identifier
vector_db_id = f"v{uuid.uuid4().hex}"
client.vector_dbs.register(
vector_db_id=vector_db_id,
embedding_model=embedding_model,
)
# Create Documents
urls = [
"memory_optimizations.rst",
"chat.rst",
"llama3.rst",
"datasets.rst",
"qat_finetune.rst",
"lora_finetune.rst",
]
documents = [
Document(
document_id=f"num-{i}",
content=f"https://raw.githubusercontent.com/pytorch/torchtune/main/docs/source/tutorials/{url}",
mime_type="text/plain",
metadata={},
)
for i, url in enumerate(urls)
]
# Insert documents
client.tool_runtime.rag_tool.insert(
documents=documents,
vector_db_id=vector_db_id,
chunk_size_in_tokens=512,
)
# Get the model being served
llm = next(m for m in client.models.list() if m.model_type == "llm")
model = llm.identifier
# Create the RAG agent
rag_agent = Agent(
client,
model=model,
instructions="You are a helpful assistant. Use the RAG tool to answer questions as needed.",
tools=[
{
"name": "builtin::rag/knowledge_search",
"args": {"vector_db_ids": [vector_db_id]},
}
],
)
session_id = rag_agent.create_session(session_name=f"s{uuid.uuid4().hex}")
turns = ["what is torchtune", "tell me about dora"]
for t in turns:
print("user>", t)
stream = rag_agent.create_turn(
messages=[{"role": "user", "content": t}], session_id=session_id, stream=True
)
for event in AgentEventLogger().log(stream):
event.print()
```
### ii. Run the Script
Let's run the script using `uv`
```bash
uv run python rag_agent.py
```
```{dropdown} 👋 Click here to see the sample output
user> what is torchtune
inference> [knowledge_search(query='TorchTune')]
tool_execution> Tool:knowledge_search Args:{'query': 'TorchTune'}
tool_execution> Tool:knowledge_search Response:[TextContentItem(text='knowledge_search tool found 5 chunks:\nBEGIN of knowledge_search tool results.\n', type='text'), TextContentItem(text='Result 1:\nDocument_id:num-1\nContent: conversational data, :func:`~torchtune.datasets.chat_dataset` seems to be a good fit. ..., type='text'), TextContentItem(text='END of knowledge_search tool results.\n', type='text')]
inference> Here is a high-level overview of the text:
**LoRA Finetuning with PyTorch Tune**
PyTorch Tune provides a recipe for LoRA (Low-Rank Adaptation) finetuning, which is a technique to adapt pre-trained models to new tasks. The recipe uses the `lora_finetune_distributed` command.
...
Overall, DORA is a powerful reinforcement learning algorithm that can learn complex tasks from human demonstrations. However, it requires careful consideration of the challenges and limitations to achieve optimal results.
```
:::
::::
**You're Ready to Build Your Own Apps!**
Congrats! 🥳 Now you're ready to [build your own Llama Stack applications](../building_applications/index)! 🚀

View file

@ -1,414 +1,65 @@
# Quick Start
# Quickstart
Get started with Llama Stack in minutes!
Llama Stack is a stateful service with REST APIs to support seamless transition of AI applications across different environments. The server can be run in a variety of ways, including as a standalone binary, Docker container, or hosted service. You can build and test using a local server first and deploy to a hosted endpoint for production.
Llama Stack is a stateful service with REST APIs to support the seamless transition of AI applications across different
environments. You can build and test using a local server first and deploy to a hosted endpoint for production.
In this guide, we'll walk through how to build a RAG agent locally using Llama Stack with [Ollama](https://ollama.com/) to run inference on a Llama Model.
### 1. Download a Llama model with Ollama
In this guide, we'll walk through how to build a RAG application locally using Llama Stack with [Ollama](https://ollama.com/)
as the inference [provider](../providers/index.md#inference) for a Llama Model.
#### Step 1: Install and setup
1. Install [uv](https://docs.astral.sh/uv/)
2. Run inference on a Llama model with [Ollama](https://ollama.com/download)
```bash
ollama pull llama3.2:3b-instruct-fp16
ollama run llama3.2:3b --keepalive 60m
```
This will instruct the Ollama service to download the Llama 3.2 3B Instruct model, which we'll use in the rest of this guide.
```{admonition} Note
:class: tip
If you do not have ollama, you can install it from [here](https://ollama.com/download).
```
### 2. Run Llama Stack locally
We use `uv` to setup a virtual environment and install the Llama Stack package.
:::{dropdown} [Click to Open] Instructions to setup uv
Install [uv](https://docs.astral.sh/uv/) to setup your virtual environment.
#### For macOS and Linux:
#### Step 2: Run the Llama Stack server
We will use `uv` to run the Llama Stack server.
```bash
curl -LsSf https://astral.sh/uv/install.sh | sh
```
#### For Windows:
Use `irm` to download the script and execute it with `iex`:
```powershell
powershell -ExecutionPolicy ByPass -c "irm https://astral.sh/uv/install.ps1 | iex"
INFERENCE_MODEL=llama3.2:3b uv run --with llama-stack llama stack build --template ollama --image-type venv --run
```
#### Step 3: Run the demo
Now open up a new terminal and copy the following script into a file named `demo_script.py`.
Setup venv
```bash
uv venv --python 3.10
source .venv/bin/activate
```
:::
**Install the Llama Stack package**
```bash
uv pip install -U llama-stack
```
**Build and Run the Llama Stack server for Ollama.**
```bash
INFERENCE_MODEL=llama3.2:3b llama stack build --template ollama --image-type venv --run
```
You will see the output end like below:
```
...
INFO: Application startup complete.
INFO: Uvicorn running on http://['::', '0.0.0.0']:8321 (Press CTRL+C to quit)
```
Now you can use the llama stack client to run inference and build agents!
### 3. Client CLI
Install the client package
```bash
pip install llama-stack-client
```
:::{dropdown} OR reuse server setup
Open a new terminal and navigate to the same directory you started the server from.
Setup venv (llama-stack already includes the llama-stack-client package)
```bash
source .venv/bin/activate
```
:::
#### 3.1 Configure the client to point to the local server
```bash
llama-stack-client configure --endpoint http://localhost:8321 --api-key none
```
You will see the below:
```
Done! You can now use the Llama Stack Client CLI with endpoint http://localhost:8321
```
#### 3.2 List available models
```
llama-stack-client models list
```
```
Available Models
┏━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┓
┃ model_type ┃ identifier ┃ provider_resource_id ┃ metadata ┃ provider_id ┃
┡━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━┩
│ embedding │ all-MiniLM-L6-v2 │ all-minilm:latest │ {'embedding_dimension': 384.0} │ ollama │
├─────────────────┼─────────────────────────────────────┼─────────────────────────────────────┼───────────────────────────────────────────┼─────────────────┤
│ llm │ llama3.2:3b │ llama3.2:3b │ │ ollama │
└─────────────────┴─────────────────────────────────────┴─────────────────────────────────────┴───────────────────────────────────────────┴─────────────────┘
Total models: 2
```
#### 3.3 Test basic inference
```bash
llama-stack-client inference chat-completion --message "tell me a joke"
```
Sample output:
```python
ChatCompletionResponse(
completion_message=CompletionMessage(
content="Here's one:\n\nWhat do you call a fake noodle?\n\nAn impasta!",
role="assistant",
stop_reason="end_of_turn",
tool_calls=[],
),
logprobs=None,
metrics=[
Metric(metric="prompt_tokens", value=14.0, unit=None),
Metric(metric="completion_tokens", value=27.0, unit=None),
Metric(metric="total_tokens", value=41.0, unit=None),
],
)
```
from llama_stack_client import Agent, AgentEventLogger, RAGDocument, LlamaStackClient
### 4. Python SDK
Install the python client
```bash
pip install llama-stack-client
```
:::{dropdown} OR reuse server setup
Open a new terminal and navigate to the same directory you started the server from.
vector_db_id = "my_demo_vector_db"
client = LlamaStackClient(base_url="http://localhost:8321")
Setup venv (llama-stack already includes the llama-stack-client package)
```bash
source .venv/bin/activate
```
:::
#### 4.1 Basic Inference
Create a file `inference.py` and add the following code:
```python
from llama_stack_client import LlamaStackClient
client = LlamaStackClient(base_url=f"http://localhost:8321")
# List available models
models = client.models.list()
# Select the first LLM
llm = next(m for m in models if m.model_type == "llm")
model_id = llm.identifier
# Select the first LLM and first embedding models
model_id = next(m for m in models if m.model_type == "llm").identifier
embedding_model_id = (
em := next(m for m in models if m.model_type == "embedding")
).identifier
embedding_dimension = em.metadata["embedding_dimension"]
print("Model:", model_id)
response = client.inference.chat_completion(
model_id=model_id,
messages=[
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "Write a haiku about coding"},
],
)
print(response.completion_message.content)
```
Run the script
```bash
python inference.py
```
Sample output:
```
Model: llama3.2:3b-instruct-fp16
Here is a haiku about coding:
Lines of code unfold
Logic flows through digital night
Beauty in the bits
```
#### 4.2. Basic Agent
Create a file `agent.py` and add the following code:
```python
from llama_stack_client import LlamaStackClient
from llama_stack_client import Agent, AgentEventLogger
from rich.pretty import pprint
import uuid
client = LlamaStackClient(base_url=f"http://localhost:8321")
models = client.models.list()
llm = next(m for m in models if m.model_type == "llm")
model_id = llm.identifier
agent = Agent(client, model=model_id, instructions="You are a helpful assistant.")
s_id = agent.create_session(session_name=f"s{uuid.uuid4().hex}")
print("Non-streaming ...")
response = agent.create_turn(
messages=[{"role": "user", "content": "Who are you?"}],
session_id=s_id,
stream=False,
)
print("agent>", response.output_message.content)
print("Streaming ...")
stream = agent.create_turn(
messages=[{"role": "user", "content": "Who are you?"}], session_id=s_id, stream=True
)
for event in stream:
pprint(event)
print("Streaming with print helper...")
stream = agent.create_turn(
messages=[{"role": "user", "content": "Who are you?"}], session_id=s_id, stream=True
)
for event in AgentEventLogger().log(stream):
event.print()
```
Run the script:
```bash
python agent.py
```
:::{dropdown} `Sample output`
```
Non-streaming ...
agent> I'm an artificial intelligence designed to assist and communicate with users like you. I don't have a personal identity, but I'm here to provide information, answer questions, and help with tasks to the best of my abilities.
I can be used for a wide range of purposes, such as:
* Providing definitions and explanations
* Offering suggestions and ideas
* Helping with language translation
* Assisting with writing and proofreading
* Generating text or responses to questions
* Playing simple games or chatting about topics of interest
I'm constantly learning and improving my abilities, so feel free to ask me anything, and I'll do my best to help!
Streaming ...
AgentTurnResponseStreamChunk(
│ event=TurnResponseEvent(
│ │ payload=AgentTurnResponseStepStartPayload(
│ │ │ event_type='step_start',
│ │ │ step_id='69831607-fa75-424a-949b-e2049e3129d1',
│ │ │ step_type='inference',
│ │ │ metadata={}
│ │ )
│ )
)
AgentTurnResponseStreamChunk(
│ event=TurnResponseEvent(
│ │ payload=AgentTurnResponseStepProgressPayload(
│ │ │ delta=TextDelta(text='As', type='text'),
│ │ │ event_type='step_progress',
│ │ │ step_id='69831607-fa75-424a-949b-e2049e3129d1',
│ │ │ step_type='inference'
│ │ )
│ )
)
AgentTurnResponseStreamChunk(
│ event=TurnResponseEvent(
│ │ payload=AgentTurnResponseStepProgressPayload(
│ │ │ delta=TextDelta(text=' a', type='text'),
│ │ │ event_type='step_progress',
│ │ │ step_id='69831607-fa75-424a-949b-e2049e3129d1',
│ │ │ step_type='inference'
│ │ )
│ )
)
...
AgentTurnResponseStreamChunk(
│ event=TurnResponseEvent(
│ │ payload=AgentTurnResponseStepCompletePayload(
│ │ │ event_type='step_complete',
│ │ │ step_details=InferenceStep(
│ │ │ │ api_model_response=CompletionMessage(
│ │ │ │ │ content='As a conversational AI, I don\'t have a personal identity in the classical sense. I exist as a program running on computer servers, designed to process and respond to text-based inputs.\n\nI\'m an instance of a type of artificial intelligence called a "language model," which is trained on vast amounts of text data to generate human-like responses. My primary function is to understand and respond to natural language inputs, like our conversation right now.\n\nThink of me as a virtual assistant, a chatbot, or a conversational interface I\'m here to provide information, answer questions, and engage in conversation to the best of my abilities. I don\'t have feelings, emotions, or consciousness like humans do, but I\'m designed to simulate human-like interactions to make our conversations feel more natural and helpful.\n\nSo, that\'s me in a nutshell! What can I help you with today?',
│ │ │ │ │ role='assistant',
│ │ │ │ │ stop_reason='end_of_turn',
│ │ │ │ │ tool_calls=[]
│ │ │ │ ),
│ │ │ │ step_id='69831607-fa75-424a-949b-e2049e3129d1',
│ │ │ │ step_type='inference',
│ │ │ │ turn_id='8b360202-f7cb-4786-baa9-166a1b46e2ca',
│ │ │ │ completed_at=datetime.datetime(2025, 4, 3, 1, 15, 21, 716174, tzinfo=TzInfo(UTC)),
│ │ │ │ started_at=datetime.datetime(2025, 4, 3, 1, 15, 14, 28823, tzinfo=TzInfo(UTC))
│ │ │ ),
│ │ │ step_id='69831607-fa75-424a-949b-e2049e3129d1',
│ │ │ step_type='inference'
│ │ )
│ )
)
AgentTurnResponseStreamChunk(
│ event=TurnResponseEvent(
│ │ payload=AgentTurnResponseTurnCompletePayload(
│ │ │ event_type='turn_complete',
│ │ │ turn=Turn(
│ │ │ │ input_messages=[UserMessage(content='Who are you?', role='user', context=None)],
│ │ │ │ output_message=CompletionMessage(
│ │ │ │ │ content='As a conversational AI, I don\'t have a personal identity in the classical sense. I exist as a program running on computer servers, designed to process and respond to text-based inputs.\n\nI\'m an instance of a type of artificial intelligence called a "language model," which is trained on vast amounts of text data to generate human-like responses. My primary function is to understand and respond to natural language inputs, like our conversation right now.\n\nThink of me as a virtual assistant, a chatbot, or a conversational interface I\'m here to provide information, answer questions, and engage in conversation to the best of my abilities. I don\'t have feelings, emotions, or consciousness like humans do, but I\'m designed to simulate human-like interactions to make our conversations feel more natural and helpful.\n\nSo, that\'s me in a nutshell! What can I help you with today?',
│ │ │ │ │ role='assistant',
│ │ │ │ │ stop_reason='end_of_turn',
│ │ │ │ │ tool_calls=[]
│ │ │ │ ),
│ │ │ │ session_id='abd4afea-4324-43f4-9513-cfe3970d92e8',
│ │ │ │ started_at=datetime.datetime(2025, 4, 3, 1, 15, 14, 28722, tzinfo=TzInfo(UTC)),
│ │ │ │ steps=[
│ │ │ │ │ InferenceStep(
│ │ │ │ │ │ api_model_response=CompletionMessage(
│ │ │ │ │ │ │ content='As a conversational AI, I don\'t have a personal identity in the classical sense. I exist as a program running on computer servers, designed to process and respond to text-based inputs.\n\nI\'m an instance of a type of artificial intelligence called a "language model," which is trained on vast amounts of text data to generate human-like responses. My primary function is to understand and respond to natural language inputs, like our conversation right now.\n\nThink of me as a virtual assistant, a chatbot, or a conversational interface I\'m here to provide information, answer questions, and engage in conversation to the best of my abilities. I don\'t have feelings, emotions, or consciousness like humans do, but I\'m designed to simulate human-like interactions to make our conversations feel more natural and helpful.\n\nSo, that\'s me in a nutshell! What can I help you with today?',
│ │ │ │ │ │ │ role='assistant',
│ │ │ │ │ │ │ stop_reason='end_of_turn',
│ │ │ │ │ │ │ tool_calls=[]
│ │ │ │ │ │ ),
│ │ │ │ │ │ step_id='69831607-fa75-424a-949b-e2049e3129d1',
│ │ │ │ │ │ step_type='inference',
│ │ │ │ │ │ turn_id='8b360202-f7cb-4786-baa9-166a1b46e2ca',
│ │ │ │ │ │ completed_at=datetime.datetime(2025, 4, 3, 1, 15, 21, 716174, tzinfo=TzInfo(UTC)),
│ │ │ │ │ │ started_at=datetime.datetime(2025, 4, 3, 1, 15, 14, 28823, tzinfo=TzInfo(UTC))
│ │ │ │ │ )
│ │ │ │ ],
│ │ │ │ turn_id='8b360202-f7cb-4786-baa9-166a1b46e2ca',
│ │ │ │ completed_at=datetime.datetime(2025, 4, 3, 1, 15, 21, 727364, tzinfo=TzInfo(UTC)),
│ │ │ │ output_attachments=[]
│ │ │ )
│ │ )
│ )
)
Streaming with print helper...
inference> Déjà vu!
As I mentioned earlier, I'm an artificial intelligence language model. I don't have a personal identity or consciousness like humans do. I exist solely to process and respond to text-based inputs, providing information and assistance on a wide range of topics.
I'm a computer program designed to simulate human-like conversations, using natural language processing (NLP) and machine learning algorithms to understand and generate responses. My purpose is to help users like you with their questions, provide information, and engage in conversation.
Think of me as a virtual companion, a helpful tool designed to make your interactions more efficient and enjoyable. I don't have personal opinions, emotions, or biases, but I'm here to provide accurate and informative responses to the best of my abilities.
So, who am I? I'm just a computer program designed to help you!
```
:::
#### 4.3. RAG agent
Create a file `rag_agent.py` and add the following code:
```python
from llama_stack_client import LlamaStackClient
from llama_stack_client import Agent, AgentEventLogger
from llama_stack_client.types import Document
import uuid
client = LlamaStackClient(base_url=f"http://localhost:8321")
# Create a vector database instance
embedlm = next(m for m in client.models.list() if m.model_type == "embedding")
embedding_model = embedlm.identifier
vector_db_id = f"v{uuid.uuid4().hex}"
client.vector_dbs.register(
_ = client.vector_dbs.register(
vector_db_id=vector_db_id,
embedding_model=embedding_model,
embedding_model=embedding_model_id,
embedding_dimension=embedding_dimension,
provider_id="faiss",
)
source = "https://www.paulgraham.com/greatwork.html"
print("rag_tool> Ingesting document:", source)
document = RAGDocument(
document_id="document_1",
content=source,
mime_type="text/html",
metadata={},
)
# Create Documents
urls = [
"memory_optimizations.rst",
"chat.rst",
"llama3.rst",
"datasets.rst",
"qat_finetune.rst",
"lora_finetune.rst",
]
documents = [
Document(
document_id=f"num-{i}",
content=f"https://raw.githubusercontent.com/pytorch/torchtune/main/docs/source/tutorials/{url}",
mime_type="text/plain",
metadata={},
)
for i, url in enumerate(urls)
]
# Insert documents
client.tool_runtime.rag_tool.insert(
documents=documents,
documents=[document],
vector_db_id=vector_db_id,
chunk_size_in_tokens=512,
chunk_size_in_tokens=50,
)
# Get the model being served
llm = next(m for m in client.models.list() if m.model_type == "llm")
model = llm.identifier
# Create RAG agent
ragagent = Agent(
agent = Agent(
client,
model=model,
instructions="You are a helpful assistant. Use the RAG tool to answer questions as needed.",
model=model_id,
instructions="You are a helpful assistant",
tools=[
{
"name": "builtin::rag/knowledge_search",
@ -417,39 +68,54 @@ ragagent = Agent(
],
)
s_id = ragagent.create_session(session_name=f"s{uuid.uuid4().hex}")
prompt = "How do you do great work?"
print("prompt>", prompt)
turns = ["what is torchtune", "tell me about dora"]
response = agent.create_turn(
messages=[{"role": "user", "content": prompt}],
session_id=agent.create_session("rag_session"),
stream=True,
)
for t in turns:
print("user>", t)
stream = ragagent.create_turn(
messages=[{"role": "user", "content": t}], session_id=s_id, stream=True
)
for event in AgentEventLogger().log(stream):
event.print()
for log in AgentEventLogger().log(response):
log.print()
```
Run the script:
We will use `uv` to run the script
```
python rag_agent.py
uv run --with llama-stack-client demo_script.py
```
:::{dropdown} `Sample output`
And you should see output like below.
```
user> what is torchtune
inference> [knowledge_search(query='TorchTune')]
tool_execution> Tool:knowledge_search Args:{'query': 'TorchTune'}
tool_execution> Tool:knowledge_search Response:[TextContentItem(text='knowledge_search tool found 5 chunks:\nBEGIN of knowledge_search tool results.\n', type='text'), TextContentItem(text='Result 1:\nDocument_id:num-1\nContent: conversational data, :func:`~torchtune.datasets.chat_dataset` seems to be a good fit. ..., type='text'), TextContentItem(text='END of knowledge_search tool results.\n', type='text')]
inference> Here is a high-level overview of the text:
rag_tool> Ingesting document: https://www.paulgraham.com/greatwork.html
**LoRA Finetuning with PyTorch Tune**
prompt> How do you do great work?
PyTorch Tune provides a recipe for LoRA (Low-Rank Adaptation) finetuning, which is a technique to adapt pre-trained models to new tasks. The recipe uses the `lora_finetune_distributed` command.
...
Overall, DORA is a powerful reinforcement learning algorithm that can learn complex tasks from human demonstrations. However, it requires careful consideration of the challenges and limitations to achieve optimal results.
inference> [knowledge_search(query="What is the key to doing great work")]
tool_execution> Tool:knowledge_search Args:{'query': 'What is the key to doing great work'}
tool_execution> Tool:knowledge_search Response:[TextContentItem(text='knowledge_search tool found 5 chunks:\nBEGIN of knowledge_search tool results.\n', type='text'), TextContentItem(text="Result 1:\nDocument_id:docum\nContent: work. Doing great work means doing something important\nso well that you expand people's ideas of what's possible. But\nthere's no threshold for importance. It's a matter of degree, and\noften hard to judge at the time anyway.\n", type='text'), TextContentItem(text="Result 2:\nDocument_id:docum\nContent: work. Doing great work means doing something important\nso well that you expand people's ideas of what's possible. But\nthere's no threshold for importance. It's a matter of degree, and\noften hard to judge at the time anyway.\n", type='text'), TextContentItem(text="Result 3:\nDocument_id:docum\nContent: work. Doing great work means doing something important\nso well that you expand people's ideas of what's possible. But\nthere's no threshold for importance. It's a matter of degree, and\noften hard to judge at the time anyway.\n", type='text'), TextContentItem(text="Result 4:\nDocument_id:docum\nContent: work. Doing great work means doing something important\nso well that you expand people's ideas of what's possible. But\nthere's no threshold for importance. It's a matter of degree, and\noften hard to judge at the time anyway.\n", type='text'), TextContentItem(text="Result 5:\nDocument_id:docum\nContent: work. Doing great work means doing something important\nso well that you expand people's ideas of what's possible. But\nthere's no threshold for importance. It's a matter of degree, and\noften hard to judge at the time anyway.\n", type='text'), TextContentItem(text='END of knowledge_search tool results.\n', type='text')]
inference> Based on the search results, it seems that doing great work means doing something important so well that you expand people's ideas of what's possible. However, there is no clear threshold for importance, and it can be difficult to judge at the time.
To further clarify, I would suggest that doing great work involves:
* Completing tasks with high quality and attention to detail
* Expanding on existing knowledge or ideas
* Making a positive impact on others through your work
* Striving for excellence and continuous improvement
Ultimately, great work is about making a meaningful contribution and leaving a lasting impression.
```
:::
Congratulations! You've successfully built your first RAG application using Llama Stack! 🎉🥳
## Next Steps
- Go through the [Getting Started Notebook](https://github.com/meta-llama/llama-stack/blob/main/docs/getting_started.ipynb)
- Checkout more [Notebooks on GitHub](https://github.com/meta-llama/llama-stack/tree/main/docs/notebooks)
- See [References](../references/index.md) for more details about the llama CLI and Python SDK
- For example applications and more detailed tutorials, visit our [llama-stack-apps](https://github.com/meta-llama/llama-stack-apps/tree/main/examples) repository.
Now you're ready to dive deeper into Llama Stack!
- Explore the [Detailed Tutorial](./detailed_tutorial.md).
- Try the [Getting Started Notebook](https://github.com/meta-llama/llama-stack/blob/main/docs/getting_started.ipynb).
- Browse more [Notebooks on GitHub](https://github.com/meta-llama/llama-stack/tree/main/docs/notebooks).
- Learn about Llama Stack [Concepts](../concepts/index.md).
- Discover how to [Build Llama Stacks](../distributions/index.md).
- Refer to our [References](../references/index.md) for details on the Llama CLI and Python SDK.
- Check out the [llama-stack-apps](https://github.com/meta-llama/llama-stack-apps/tree/main/examples) repository for example applications and tutorials.

View file

@ -1,3 +1,5 @@
# Llama Stack
Welcome to Llama Stack, the open-source framework for building generative AI applications.
```{admonition} Llama 4 is here!
:class: tip
@ -9,7 +11,6 @@ Check out [Getting Started with Llama 4](https://colab.research.google.com/githu
Llama Stack {{ llama_stack_version }} is now available! See the {{ llama_stack_version_link }} for more details.
```
# Llama Stack
## What is Llama Stack?
@ -98,8 +99,9 @@ A number of "adapters" are available for some popular Inference and Vector Store
:maxdepth: 3
self
introduction/index
getting_started/index
getting_started/detailed_tutorial
introduction/index
concepts/index
providers/index
distributions/index

View file

@ -103,7 +103,5 @@ llama stack run together
2. Start Streamlit UI
```bash
cd llama_stack/distribution/ui
pip install -r requirements.txt
streamlit run app.py
uv run --with ".[ui]" streamlit run llama_stack/distribution/ui/app.py
```

View file

@ -50,9 +50,10 @@ Llama Stack supports two types of external providers:
Here's a list of known external providers that you can use with Llama Stack:
| Type | Name | Description | Repository |
|------|------|-------------|------------|
| Remote | KubeFlow Training | Train models with KubeFlow | [llama-stack-provider-kft](https://github.com/opendatahub-io/llama-stack-provider-kft) |
| Name | Description | API | Type | Repository |
|------|-------------|-----|------|------------|
| KubeFlow Training | Train models with KubeFlow | Post Training | Remote | [llama-stack-provider-kft](https://github.com/opendatahub-io/llama-stack-provider-kft) |
| RamaLama | Inference models with RamaLama | Inference | Remote | [llama-stack-provider-ramalama](https://github.com/containers/llama-stack-provider-ramalama) |
### Remote Provider Specification

View file

@ -1,8 +1,8 @@
# Providers Overview
The goal of Llama Stack is to build an ecosystem where users can easily swap out different implementations for the same API. Examples for these include:
- LLM inference providers (e.g., Fireworks, Together, AWS Bedrock, Groq, Cerebras, SambaNova, vLLM, etc.),
- Vector databases (e.g., ChromaDB, Weaviate, Qdrant, Milvus, FAISS, PGVector, etc.),
- LLM inference providers (e.g., Ollama, Fireworks, Together, AWS Bedrock, Groq, Cerebras, SambaNova, vLLM, etc.),
- Vector databases (e.g., ChromaDB, Weaviate, Qdrant, Milvus, FAISS, PGVector, SQLite-Vec, etc.),
- Safety providers (e.g., Meta's Llama Guard, AWS Bedrock Guardrails, etc.)
Providers come in two flavors:

View file

@ -225,8 +225,18 @@ class AgentConfigCommon(BaseModel):
@json_schema_type
class AgentConfig(AgentConfigCommon):
"""Configuration for an agent.
:param model: The model identifier to use for the agent
:param instructions: The system instructions for the agent
:param name: Optional name for the agent, used in telemetry and identification
:param enable_session_persistence: Optional flag indicating whether session data has to be persisted
:param response_format: Optional response format configuration
"""
model: str
instructions: str
name: Optional[str] = None
enable_session_persistence: Optional[bool] = False
response_format: Optional[ResponseFormat] = None

View file

@ -6,11 +6,8 @@
from typing import List, Optional, Protocol, runtime_checkable
from pydantic import BaseModel
from llama_stack.apis.common.job_types import Job
from llama_stack.apis.inference import (
ChatCompletionResponse,
CompletionResponse,
InterleavedContent,
LogProbConfig,
Message,
@ -20,41 +17,39 @@ from llama_stack.apis.inference import (
ToolDefinition,
ToolPromptFormat,
)
from llama_stack.schema_utils import json_schema_type, webmethod
@json_schema_type
class BatchCompletionResponse(BaseModel):
batch: List[CompletionResponse]
@json_schema_type
class BatchChatCompletionResponse(BaseModel):
batch: List[ChatCompletionResponse]
from llama_stack.schema_utils import webmethod
@runtime_checkable
class BatchInference(Protocol):
"""Batch inference API for generating completions and chat completions.
This is an asynchronous API. If the request is successful, the response will be a job which can be polled for completion.
NOTE: This API is not yet implemented and is subject to change in concert with other asynchronous APIs
including (post-training, evals, etc).
"""
@webmethod(route="/batch-inference/completion", method="POST")
async def batch_completion(
async def completion(
self,
model: str,
content_batch: List[InterleavedContent],
sampling_params: Optional[SamplingParams] = None,
response_format: Optional[ResponseFormat] = None,
logprobs: Optional[LogProbConfig] = None,
) -> BatchCompletionResponse: ...
) -> Job: ...
@webmethod(route="/batch-inference/chat-completion", method="POST")
async def batch_chat_completion(
async def chat_completion(
self,
model: str,
messages_batch: List[List[Message]],
sampling_params: Optional[SamplingParams] = None,
# zero-shot tool definitions as input to the model
tools: Optional[List[ToolDefinition]] = list,
tools: Optional[List[ToolDefinition]] = None,
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
tool_prompt_format: Optional[ToolPromptFormat] = None,
response_format: Optional[ResponseFormat] = None,
logprobs: Optional[LogProbConfig] = None,
) -> BatchChatCompletionResponse: ...
) -> Job: ...

View file

@ -18,7 +18,7 @@ from typing import (
)
from pydantic import BaseModel, Field, field_validator
from typing_extensions import Annotated
from typing_extensions import Annotated, TypedDict
from llama_stack.apis.common.content_types import ContentDelta, InterleavedContent, InterleavedContentItem
from llama_stack.apis.models import Model
@ -442,6 +442,352 @@ class EmbeddingsResponse(BaseModel):
embeddings: List[List[float]]
@json_schema_type
class OpenAIChatCompletionContentPartTextParam(BaseModel):
type: Literal["text"] = "text"
text: str
@json_schema_type
class OpenAIImageURL(BaseModel):
url: str
detail: Optional[str] = None
@json_schema_type
class OpenAIChatCompletionContentPartImageParam(BaseModel):
type: Literal["image_url"] = "image_url"
image_url: OpenAIImageURL
OpenAIChatCompletionContentPartParam = Annotated[
Union[
OpenAIChatCompletionContentPartTextParam,
OpenAIChatCompletionContentPartImageParam,
],
Field(discriminator="type"),
]
register_schema(OpenAIChatCompletionContentPartParam, name="OpenAIChatCompletionContentPartParam")
OpenAIChatCompletionMessageContent = Union[str, List[OpenAIChatCompletionContentPartParam]]
@json_schema_type
class OpenAIUserMessageParam(BaseModel):
"""A message from the user in an OpenAI-compatible chat completion request.
:param role: Must be "user" to identify this as a user message
:param content: The content of the message, which can include text and other media
:param name: (Optional) The name of the user message participant.
"""
role: Literal["user"] = "user"
content: OpenAIChatCompletionMessageContent
name: Optional[str] = None
@json_schema_type
class OpenAISystemMessageParam(BaseModel):
"""A system message providing instructions or context to the model.
:param role: Must be "system" to identify this as a system message
:param content: The content of the "system prompt". If multiple system messages are provided, they are concatenated. The underlying Llama Stack code may also add other system messages (for example, for formatting tool definitions).
:param name: (Optional) The name of the system message participant.
"""
role: Literal["system"] = "system"
content: OpenAIChatCompletionMessageContent
name: Optional[str] = None
@json_schema_type
class OpenAIChatCompletionToolCallFunction(BaseModel):
name: Optional[str] = None
arguments: Optional[str] = None
@json_schema_type
class OpenAIChatCompletionToolCall(BaseModel):
index: Optional[int] = None
id: Optional[str] = None
type: Literal["function"] = "function"
function: Optional[OpenAIChatCompletionToolCallFunction] = None
@json_schema_type
class OpenAIAssistantMessageParam(BaseModel):
"""A message containing the model's (assistant) response in an OpenAI-compatible chat completion request.
:param role: Must be "assistant" to identify this as the model's response
:param content: The content of the model's response
:param name: (Optional) The name of the assistant message participant.
:param tool_calls: List of tool calls. Each tool call is an OpenAIChatCompletionToolCall object.
"""
role: Literal["assistant"] = "assistant"
content: Optional[OpenAIChatCompletionMessageContent] = None
name: Optional[str] = None
tool_calls: Optional[List[OpenAIChatCompletionToolCall]] = None
@json_schema_type
class OpenAIToolMessageParam(BaseModel):
"""A message representing the result of a tool invocation in an OpenAI-compatible chat completion request.
:param role: Must be "tool" to identify this as a tool response
:param tool_call_id: Unique identifier for the tool call this response is for
:param content: The response content from the tool
"""
role: Literal["tool"] = "tool"
tool_call_id: str
content: OpenAIChatCompletionMessageContent
@json_schema_type
class OpenAIDeveloperMessageParam(BaseModel):
"""A message from the developer in an OpenAI-compatible chat completion request.
:param role: Must be "developer" to identify this as a developer message
:param content: The content of the developer message
:param name: (Optional) The name of the developer message participant.
"""
role: Literal["developer"] = "developer"
content: OpenAIChatCompletionMessageContent
name: Optional[str] = None
OpenAIMessageParam = Annotated[
Union[
OpenAIUserMessageParam,
OpenAISystemMessageParam,
OpenAIAssistantMessageParam,
OpenAIToolMessageParam,
OpenAIDeveloperMessageParam,
],
Field(discriminator="role"),
]
register_schema(OpenAIMessageParam, name="OpenAIMessageParam")
@json_schema_type
class OpenAIResponseFormatText(BaseModel):
type: Literal["text"] = "text"
@json_schema_type
class OpenAIJSONSchema(TypedDict, total=False):
name: str
description: Optional[str] = None
strict: Optional[bool] = None
# Pydantic BaseModel cannot be used with a schema param, since it already
# has one. And, we don't want to alias here because then have to handle
# that alias when converting to OpenAI params. So, to support schema,
# we use a TypedDict.
schema: Optional[Dict[str, Any]] = None
@json_schema_type
class OpenAIResponseFormatJSONSchema(BaseModel):
type: Literal["json_schema"] = "json_schema"
json_schema: OpenAIJSONSchema
@json_schema_type
class OpenAIResponseFormatJSONObject(BaseModel):
type: Literal["json_object"] = "json_object"
OpenAIResponseFormatParam = Annotated[
Union[
OpenAIResponseFormatText,
OpenAIResponseFormatJSONSchema,
OpenAIResponseFormatJSONObject,
],
Field(discriminator="type"),
]
register_schema(OpenAIResponseFormatParam, name="OpenAIResponseFormatParam")
@json_schema_type
class OpenAITopLogProb(BaseModel):
"""The top log probability for a token from an OpenAI-compatible chat completion response.
:token: The token
:bytes: (Optional) The bytes for the token
:logprob: The log probability of the token
"""
token: str
bytes: Optional[List[int]] = None
logprob: float
@json_schema_type
class OpenAITokenLogProb(BaseModel):
"""The log probability for a token from an OpenAI-compatible chat completion response.
:token: The token
:bytes: (Optional) The bytes for the token
:logprob: The log probability of the token
:top_logprobs: The top log probabilities for the token
"""
token: str
bytes: Optional[List[int]] = None
logprob: float
top_logprobs: List[OpenAITopLogProb]
@json_schema_type
class OpenAIChoiceLogprobs(BaseModel):
"""The log probabilities for the tokens in the message from an OpenAI-compatible chat completion response.
:param content: (Optional) The log probabilities for the tokens in the message
:param refusal: (Optional) The log probabilities for the tokens in the message
"""
content: Optional[List[OpenAITokenLogProb]] = None
refusal: Optional[List[OpenAITokenLogProb]] = None
@json_schema_type
class OpenAIChoiceDelta(BaseModel):
"""A delta from an OpenAI-compatible chat completion streaming response.
:param content: (Optional) The content of the delta
:param refusal: (Optional) The refusal of the delta
:param role: (Optional) The role of the delta
:param tool_calls: (Optional) The tool calls of the delta
"""
content: Optional[str] = None
refusal: Optional[str] = None
role: Optional[str] = None
tool_calls: Optional[List[OpenAIChatCompletionToolCall]] = None
@json_schema_type
class OpenAIChunkChoice(BaseModel):
"""A chunk choice from an OpenAI-compatible chat completion streaming response.
:param delta: The delta from the chunk
:param finish_reason: The reason the model stopped generating
:param index: The index of the choice
:param logprobs: (Optional) The log probabilities for the tokens in the message
"""
delta: OpenAIChoiceDelta
finish_reason: str
index: int
logprobs: Optional[OpenAIChoiceLogprobs] = None
@json_schema_type
class OpenAIChoice(BaseModel):
"""A choice from an OpenAI-compatible chat completion response.
:param message: The message from the model
:param finish_reason: The reason the model stopped generating
:param index: The index of the choice
:param logprobs: (Optional) The log probabilities for the tokens in the message
"""
message: OpenAIMessageParam
finish_reason: str
index: int
logprobs: Optional[OpenAIChoiceLogprobs] = None
@json_schema_type
class OpenAIChatCompletion(BaseModel):
"""Response from an OpenAI-compatible chat completion request.
:param id: The ID of the chat completion
:param choices: List of choices
:param object: The object type, which will be "chat.completion"
:param created: The Unix timestamp in seconds when the chat completion was created
:param model: The model that was used to generate the chat completion
"""
id: str
choices: List[OpenAIChoice]
object: Literal["chat.completion"] = "chat.completion"
created: int
model: str
@json_schema_type
class OpenAIChatCompletionChunk(BaseModel):
"""Chunk from a streaming response to an OpenAI-compatible chat completion request.
:param id: The ID of the chat completion
:param choices: List of choices
:param object: The object type, which will be "chat.completion.chunk"
:param created: The Unix timestamp in seconds when the chat completion was created
:param model: The model that was used to generate the chat completion
"""
id: str
choices: List[OpenAIChunkChoice]
object: Literal["chat.completion.chunk"] = "chat.completion.chunk"
created: int
model: str
@json_schema_type
class OpenAICompletionLogprobs(BaseModel):
"""The log probabilities for the tokens in the message from an OpenAI-compatible completion response.
:text_offset: (Optional) The offset of the token in the text
:token_logprobs: (Optional) The log probabilities for the tokens
:tokens: (Optional) The tokens
:top_logprobs: (Optional) The top log probabilities for the tokens
"""
text_offset: Optional[List[int]] = None
token_logprobs: Optional[List[float]] = None
tokens: Optional[List[str]] = None
top_logprobs: Optional[List[Dict[str, float]]] = None
@json_schema_type
class OpenAICompletionChoice(BaseModel):
"""A choice from an OpenAI-compatible completion response.
:finish_reason: The reason the model stopped generating
:text: The text of the choice
:index: The index of the choice
:logprobs: (Optional) The log probabilities for the tokens in the choice
"""
finish_reason: str
text: str
index: int
logprobs: Optional[OpenAIChoiceLogprobs] = None
@json_schema_type
class OpenAICompletion(BaseModel):
"""Response from an OpenAI-compatible completion request.
:id: The ID of the completion
:choices: List of choices
:created: The Unix timestamp in seconds when the completion was created
:model: The model that was used to generate the completion
:object: The object type, which will be "text_completion"
"""
id: str
choices: List[OpenAICompletionChoice]
created: int
model: str
object: Literal["text_completion"] = "text_completion"
class ModelStore(Protocol):
async def get_model(self, identifier: str) -> Model: ...
@ -470,6 +816,16 @@ class EmbeddingTaskType(Enum):
document = "document"
@json_schema_type
class BatchCompletionResponse(BaseModel):
batch: List[CompletionResponse]
@json_schema_type
class BatchChatCompletionResponse(BaseModel):
batch: List[ChatCompletionResponse]
@runtime_checkable
@trace_protocol
class Inference(Protocol):
@ -505,6 +861,17 @@ class Inference(Protocol):
"""
...
@webmethod(route="/inference/batch-completion", method="POST", experimental=True)
async def batch_completion(
self,
model_id: str,
content_batch: List[InterleavedContent],
sampling_params: Optional[SamplingParams] = None,
response_format: Optional[ResponseFormat] = None,
logprobs: Optional[LogProbConfig] = None,
) -> BatchCompletionResponse:
raise NotImplementedError("Batch completion is not implemented")
@webmethod(route="/inference/chat-completion", method="POST")
async def chat_completion(
self,
@ -545,6 +912,19 @@ class Inference(Protocol):
"""
...
@webmethod(route="/inference/batch-chat-completion", method="POST", experimental=True)
async def batch_chat_completion(
self,
model_id: str,
messages_batch: List[List[Message]],
sampling_params: Optional[SamplingParams] = None,
tools: Optional[List[ToolDefinition]] = None,
tool_config: Optional[ToolConfig] = None,
response_format: Optional[ResponseFormat] = None,
logprobs: Optional[LogProbConfig] = None,
) -> BatchChatCompletionResponse:
raise NotImplementedError("Batch chat completion is not implemented")
@webmethod(route="/inference/embeddings", method="POST")
async def embeddings(
self,
@ -564,3 +944,105 @@ class Inference(Protocol):
:returns: An array of embeddings, one for each content. Each embedding is a list of floats. The dimensionality of the embedding is model-specific; you can check model metadata using /models/{model_id}
"""
...
@webmethod(route="/openai/v1/completions", method="POST")
async def openai_completion(
self,
# Standard OpenAI completion parameters
model: str,
prompt: Union[str, List[str], List[int], List[List[int]]],
best_of: Optional[int] = None,
echo: Optional[bool] = None,
frequency_penalty: Optional[float] = None,
logit_bias: Optional[Dict[str, float]] = None,
logprobs: Optional[bool] = None,
max_tokens: Optional[int] = None,
n: Optional[int] = None,
presence_penalty: Optional[float] = None,
seed: Optional[int] = None,
stop: Optional[Union[str, List[str]]] = None,
stream: Optional[bool] = None,
stream_options: Optional[Dict[str, Any]] = None,
temperature: Optional[float] = None,
top_p: Optional[float] = None,
user: Optional[str] = None,
# vLLM-specific parameters
guided_choice: Optional[List[str]] = None,
prompt_logprobs: Optional[int] = None,
) -> OpenAICompletion:
"""Generate an OpenAI-compatible completion for the given prompt using the specified model.
:param model: The identifier of the model to use. The model must be registered with Llama Stack and available via the /models endpoint.
:param prompt: The prompt to generate a completion for
:param best_of: (Optional) The number of completions to generate
:param echo: (Optional) Whether to echo the prompt
:param frequency_penalty: (Optional) The penalty for repeated tokens
:param logit_bias: (Optional) The logit bias to use
:param logprobs: (Optional) The log probabilities to use
:param max_tokens: (Optional) The maximum number of tokens to generate
:param n: (Optional) The number of completions to generate
:param presence_penalty: (Optional) The penalty for repeated tokens
:param seed: (Optional) The seed to use
:param stop: (Optional) The stop tokens to use
:param stream: (Optional) Whether to stream the response
:param stream_options: (Optional) The stream options to use
:param temperature: (Optional) The temperature to use
:param top_p: (Optional) The top p to use
:param user: (Optional) The user to use
"""
...
@webmethod(route="/openai/v1/chat/completions", method="POST")
async def openai_chat_completion(
self,
model: str,
messages: List[OpenAIMessageParam],
frequency_penalty: Optional[float] = None,
function_call: Optional[Union[str, Dict[str, Any]]] = None,
functions: Optional[List[Dict[str, Any]]] = None,
logit_bias: Optional[Dict[str, float]] = None,
logprobs: Optional[bool] = None,
max_completion_tokens: Optional[int] = None,
max_tokens: Optional[int] = None,
n: Optional[int] = None,
parallel_tool_calls: Optional[bool] = None,
presence_penalty: Optional[float] = None,
response_format: Optional[OpenAIResponseFormatParam] = None,
seed: Optional[int] = None,
stop: Optional[Union[str, List[str]]] = None,
stream: Optional[bool] = None,
stream_options: Optional[Dict[str, Any]] = None,
temperature: Optional[float] = None,
tool_choice: Optional[Union[str, Dict[str, Any]]] = None,
tools: Optional[List[Dict[str, Any]]] = None,
top_logprobs: Optional[int] = None,
top_p: Optional[float] = None,
user: Optional[str] = None,
) -> Union[OpenAIChatCompletion, AsyncIterator[OpenAIChatCompletionChunk]]:
"""Generate an OpenAI-compatible chat completion for the given messages using the specified model.
:param model: The identifier of the model to use. The model must be registered with Llama Stack and available via the /models endpoint.
:param messages: List of messages in the conversation
:param frequency_penalty: (Optional) The penalty for repeated tokens
:param function_call: (Optional) The function call to use
:param functions: (Optional) List of functions to use
:param logit_bias: (Optional) The logit bias to use
:param logprobs: (Optional) The log probabilities to use
:param max_completion_tokens: (Optional) The maximum number of tokens to generate
:param max_tokens: (Optional) The maximum number of tokens to generate
:param n: (Optional) The number of completions to generate
:param parallel_tool_calls: (Optional) Whether to parallelize tool calls
:param presence_penalty: (Optional) The penalty for repeated tokens
:param response_format: (Optional) The response format to use
:param seed: (Optional) The seed to use
:param stop: (Optional) The stop tokens to use
:param stream: (Optional) Whether to stream the response
:param stream_options: (Optional) The stream options to use
:param temperature: (Optional) The temperature to use
:param tool_choice: (Optional) The tool choice to use
:param tools: (Optional) The tools to use
:param top_logprobs: (Optional) The top log probabilities to use
:param top_p: (Optional) The top p to use
:param user: (Optional) The user to use
"""
...

View file

@ -8,6 +8,7 @@ from typing import List, Protocol, runtime_checkable
from pydantic import BaseModel
from llama_stack.providers.datatypes import HealthStatus
from llama_stack.schema_utils import json_schema_type, webmethod
@ -20,8 +21,7 @@ class RouteInfo(BaseModel):
@json_schema_type
class HealthInfo(BaseModel):
status: str
# TODO: add a provider level status
status: HealthStatus
@json_schema_type

View file

@ -56,12 +56,35 @@ class ListModelsResponse(BaseModel):
data: List[Model]
@json_schema_type
class OpenAIModel(BaseModel):
"""A model from OpenAI.
:id: The ID of the model
:object: The object type, which will be "model"
:created: The Unix timestamp in seconds when the model was created
:owned_by: The owner of the model
"""
id: str
object: Literal["model"] = "model"
created: int
owned_by: str
class OpenAIListModelsResponse(BaseModel):
data: List[OpenAIModel]
@runtime_checkable
@trace_protocol
class Models(Protocol):
@webmethod(route="/models", method="GET")
async def list_models(self) -> ListModelsResponse: ...
@webmethod(route="/openai/v1/models", method="GET")
async def openai_list_models(self) -> OpenAIListModelsResponse: ...
@webmethod(route="/models/{model_id:path}", method="GET")
async def get_model(
self,

View file

@ -60,11 +60,11 @@ class EfficiencyConfig(BaseModel):
@json_schema_type
class TrainingConfig(BaseModel):
n_epochs: int
max_steps_per_epoch: int
gradient_accumulation_steps: int
max_validation_steps: int
data_config: DataConfig
optimizer_config: OptimizerConfig
max_steps_per_epoch: int = 1
gradient_accumulation_steps: int = 1
max_validation_steps: Optional[int] = 1
data_config: Optional[DataConfig] = None
optimizer_config: Optional[OptimizerConfig] = None
efficiency_config: Optional[EfficiencyConfig] = None
dtype: Optional[str] = "bf16"
@ -177,9 +177,9 @@ class PostTraining(Protocol):
training_config: TrainingConfig,
hyperparam_search_config: Dict[str, Any],
logger_config: Dict[str, Any],
model: str = Field(
default="Llama3.2-3B-Instruct",
description="Model descriptor from `llama model list`",
model: Optional[str] = Field(
default=None,
description="Model descriptor for training if not in provider config`",
),
checkpoint_dir: Optional[str] = None,
algorithm_config: Optional[AlgorithmConfig] = None,

View file

@ -8,6 +8,7 @@ from typing import Any, Dict, List, Protocol, runtime_checkable
from pydantic import BaseModel
from llama_stack.providers.datatypes import HealthResponse
from llama_stack.schema_utils import json_schema_type, webmethod
@ -17,6 +18,7 @@ class ProviderInfo(BaseModel):
provider_id: str
provider_type: str
config: Dict[str, Any]
health: HealthResponse
class ListProvidersResponse(BaseModel):

View file

@ -89,6 +89,43 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
color="red",
)
sys.exit(1)
elif args.providers:
providers = dict()
for api_provider in args.providers.split(","):
if "=" not in api_provider:
cprint(
"Could not parse `--providers`. Please ensure the list is in the format api1=provider1,api2=provider2",
color="red",
)
sys.exit(1)
api, provider = api_provider.split("=")
providers_for_api = get_provider_registry().get(Api(api), None)
if providers_for_api is None:
cprint(
f"{api} is not a valid API.",
color="red",
)
sys.exit(1)
if provider in providers_for_api:
providers.setdefault(api, []).append(provider)
else:
cprint(
f"{provider} is not a valid provider for the {api} API.",
color="red",
)
sys.exit(1)
distribution_spec = DistributionSpec(
providers=providers,
description=",".join(args.providers),
)
if not args.image_type:
cprint(
f"Please specify a image-type (container | conda | venv) for {args.template}",
color="red",
)
sys.exit(1)
build_config = BuildConfig(image_type=args.image_type, distribution_spec=distribution_spec)
elif not args.config and not args.template:
name = prompt(
"> Enter a name for your Llama Stack (e.g. my-local-stack): ",
@ -173,16 +210,9 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
)
sys.exit(1)
if build_config.image_type == LlamaStackImageType.CONTAINER.value and not args.image_name:
cprint(
"Please specify --image-name when building a container from a config file",
color="red",
)
sys.exit(1)
if args.print_deps_only:
print(f"# Dependencies for {args.template or args.config or image_name}")
normal_deps, special_deps = get_provider_dependencies(build_config.distribution_spec.providers)
normal_deps, special_deps = get_provider_dependencies(build_config)
normal_deps += SERVER_DEPENDENCIES
print(f"uv pip install {' '.join(normal_deps)}")
for special_dep in special_deps:
@ -198,10 +228,14 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
)
except (Exception, RuntimeError) as exc:
import traceback
cprint(
f"Error building stack: {exc}",
color="red",
)
cprint("Stack trace:", color="red")
traceback.print_exc()
sys.exit(1)
if run_config is None:
cprint(
@ -233,9 +267,10 @@ def _generate_run_config(
image_name=image_name,
apis=apis,
providers={},
external_providers_dir=build_config.external_providers_dir if build_config.external_providers_dir else None,
)
# build providers dict
provider_registry = get_provider_registry()
provider_registry = get_provider_registry(build_config)
for api in apis:
run_config.providers[api] = []
provider_types = build_config.distribution_spec.providers[api]
@ -249,8 +284,22 @@ def _generate_run_config(
if p.deprecation_error:
raise InvalidProviderError(p.deprecation_error)
config_type = instantiate_class_type(provider_registry[Api(api)][provider_type].config_class)
if hasattr(config_type, "sample_run_config"):
try:
config_type = instantiate_class_type(provider_registry[Api(api)][provider_type].config_class)
except ModuleNotFoundError:
# HACK ALERT:
# This code executes after building is done, the import cannot work since the
# package is either available in the venv or container - not available on the host.
# TODO: use a "is_external" flag in ProviderSpec to check if the provider is
# external
cprint(
f"Failed to import provider {provider_type} for API {api} - assuming it's external, skipping",
color="yellow",
)
# Set config_type to None to avoid UnboundLocalError
config_type = None
if config_type is not None and hasattr(config_type, "sample_run_config"):
config = config_type.sample_run_config(__distro_dir__=f"~/.llama/distributions/{image_name}")
else:
config = {}
@ -282,6 +331,7 @@ def _run_stack_build_command_from_build_config(
template_name: Optional[str] = None,
config_path: Optional[str] = None,
) -> str:
image_name = image_name or build_config.image_name
if build_config.image_type == LlamaStackImageType.CONTAINER.value:
if template_name:
image_name = f"distribution-{template_name}"
@ -313,7 +363,7 @@ def _run_stack_build_command_from_build_config(
build_config,
build_file_path,
image_name,
template_or_config=template_name or config_path,
template_or_config=template_name or config_path or str(build_file_path),
)
if return_code != 0:
raise RuntimeError(f"Failed to build image {image_name}")

View file

@ -57,7 +57,7 @@ class StackBuild(Subcommand):
type=str,
help=textwrap.dedent(
f"""[for image-type={"|".join(e.value for e in ImageType)}] Name of the conda or virtual environment to use for
the build. If not specified, currently active Conda environment will be used if found.
the build. If not specified, currently active environment will be used if found.
"""
),
default=None,
@ -75,6 +75,12 @@ the build. If not specified, currently active Conda environment will be used if
default=False,
help="Run the stack after building using the same image type, name, and other applicable arguments",
)
self.parser.add_argument(
"--providers",
type=str,
default=None,
help="Build a config for a list of providers and only those providers. This list is formatted like: api1=provider1,api2=provider2. Where there can be multiple providers per API.",
)
def _run_stack_build_command(self, args: argparse.Namespace) -> None:
# always keep implementation completely silo-ed away from CLI so CLI

View file

@ -45,7 +45,7 @@ class StackRun(Subcommand):
"--image-name",
type=str,
default=os.environ.get("CONDA_DEFAULT_ENV"),
help="Name of the image to run. Defaults to the current conda environment",
help="Name of the image to run. Defaults to the current environment",
)
self.parser.add_argument(
"--disable-ipv6",

View file

@ -7,16 +7,16 @@
import importlib.resources
import logging
from pathlib import Path
from typing import Dict, List
from pydantic import BaseModel
from termcolor import cprint
from llama_stack.distribution.datatypes import BuildConfig, Provider
from llama_stack.distribution.datatypes import BuildConfig
from llama_stack.distribution.distribution import get_provider_registry
from llama_stack.distribution.utils.exec import run_command
from llama_stack.distribution.utils.image_types import LlamaStackImageType
from llama_stack.providers.datatypes import Api
from llama_stack.templates.template import DistributionTemplate
log = logging.getLogger(__name__)
@ -37,19 +37,24 @@ class ApiInput(BaseModel):
def get_provider_dependencies(
config_providers: Dict[str, List[Provider]],
config: BuildConfig | DistributionTemplate,
) -> tuple[list[str], list[str]]:
"""Get normal and special dependencies from provider configuration."""
all_providers = get_provider_registry()
# Extract providers based on config type
if isinstance(config, DistributionTemplate):
providers = config.providers
elif isinstance(config, BuildConfig):
providers = config.distribution_spec.providers
deps = []
registry = get_provider_registry(config)
for api_str, provider_or_providers in config_providers.items():
providers_for_api = all_providers[Api(api_str)]
for api_str, provider_or_providers in providers.items():
providers_for_api = registry[Api(api_str)]
providers = provider_or_providers if isinstance(provider_or_providers, list) else [provider_or_providers]
for provider in providers:
# Providers from BuildConfig and RunConfig are subtly different  not great
# Providers from BuildConfig and RunConfig are subtly different not great
provider_type = provider if isinstance(provider, str) else provider.provider_type
if provider_type not in providers_for_api:
@ -71,8 +76,8 @@ def get_provider_dependencies(
return list(set(normal_deps)), list(set(special_deps))
def print_pip_install_help(providers: Dict[str, List[Provider]]):
normal_deps, special_deps = get_provider_dependencies(providers)
def print_pip_install_help(config: BuildConfig):
normal_deps, special_deps = get_provider_dependencies(config)
cprint(
f"Please install needed dependencies using the following commands:\n\nuv pip install {' '.join(normal_deps)}",
@ -91,7 +96,7 @@ def build_image(
):
container_base = build_config.distribution_spec.container_image or "python:3.10-slim"
normal_deps, special_deps = get_provider_dependencies(build_config.distribution_spec.providers)
normal_deps, special_deps = get_provider_dependencies(build_config)
normal_deps += SERVER_DEPENDENCIES
if build_config.image_type == LlamaStackImageType.CONTAINER.value:

View file

@ -72,9 +72,13 @@ if [[ $container_base == *"registry.access.redhat.com/ubi9"* ]]; then
FROM $container_base
WORKDIR /app
# We install the Python 3.11 dev headers and build tools so that any
# Cextension wheels (e.g. polyleven, faisscpu) can compile successfully.
RUN dnf -y update && dnf install -y iputils net-tools wget \
vim-minimal python3.11 python3.11-pip python3.11-wheel \
python3.11-setuptools && ln -s /bin/pip3.11 /bin/pip && ln -s /bin/python3.11 /bin/python && dnf clean all
python3.11-setuptools python3.11-devel gcc make && \
ln -s /bin/pip3.11 /bin/pip && ln -s /bin/python3.11 /bin/python && dnf clean all
ENV UV_SYSTEM_PYTHON=1
RUN pip install uv
@ -86,7 +90,7 @@ WORKDIR /app
RUN apt-get update && apt-get install -y \
iputils-ping net-tools iproute2 dnsutils telnet \
curl wget telnet \
curl wget telnet git\
procps psmisc lsof \
traceroute \
bubblewrap \

View file

@ -326,3 +326,12 @@ class BuildConfig(BaseModel):
default="conda",
description="Type of package to build (conda | container | venv)",
)
image_name: Optional[str] = Field(
default=None,
description="Name of the distribution to build",
)
external_providers_dir: Optional[str] = Field(
default=None,
description="Path to directory containing external provider implementations. The providers packages will be resolved from this directory. "
"pip_packages MUST contain the provider package name.",
)

View file

@ -12,7 +12,6 @@ from typing import Any, Dict, List
import yaml
from pydantic import BaseModel
from llama_stack.distribution.datatypes import StackRunConfig
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import (
AdapterSpec,
@ -97,7 +96,9 @@ def _load_inline_provider_spec(spec_data: Dict[str, Any], api: Api, provider_nam
return spec
def get_provider_registry(config: StackRunConfig | None = None) -> Dict[Api, Dict[str, ProviderSpec]]:
def get_provider_registry(
config=None,
) -> Dict[Api, Dict[str, ProviderSpec]]:
"""Get the provider registry, optionally including external providers.
This function loads both built-in providers and external providers from YAML files.
@ -122,7 +123,7 @@ def get_provider_registry(config: StackRunConfig | None = None) -> Dict[Api, Dic
llama-guard.yaml
Args:
config: Optional StackRunConfig containing the external providers directory path
config: Optional object containing the external providers directory path
Returns:
A dictionary mapping APIs to their available providers
@ -142,7 +143,8 @@ def get_provider_registry(config: StackRunConfig | None = None) -> Dict[Api, Dic
except ImportError as e:
logger.warning(f"Failed to import module {name}: {e}")
if config and config.external_providers_dir:
# Check if config has the external_providers_dir attribute
if config and hasattr(config, "external_providers_dir") and config.external_providers_dir:
external_providers_dir = os.path.abspath(config.external_providers_dir)
if not os.path.exists(external_providers_dir):
raise FileNotFoundError(f"External providers directory not found: {external_providers_dir}")

View file

@ -17,6 +17,7 @@ from llama_stack.apis.inspect import (
)
from llama_stack.distribution.datatypes import StackRunConfig
from llama_stack.distribution.server.endpoints import get_all_api_endpoints
from llama_stack.providers.datatypes import HealthStatus
class DistributionInspectConfig(BaseModel):
@ -58,7 +59,7 @@ class DistributionInspectImpl(Inspect):
return ListRoutesResponse(data=ret)
async def health(self) -> HealthInfo:
return HealthInfo(status="OK")
return HealthInfo(status=HealthStatus.OK)
async def version(self) -> VersionInfo:
return VersionInfo(version=version("llama-stack"))

View file

@ -43,9 +43,9 @@ from llama_stack.distribution.server.endpoints import (
from llama_stack.distribution.stack import (
construct_stack,
get_stack_run_config_from_template,
redact_sensitive_fields,
replace_env_vars,
)
from llama_stack.distribution.utils.config import redact_sensitive_fields
from llama_stack.distribution.utils.context import preserve_contexts_async_generator
from llama_stack.distribution.utils.exec import in_notebook
from llama_stack.providers.utils.telemetry.tracing import (

View file

@ -4,14 +4,17 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import asyncio
from typing import Any, Dict
from pydantic import BaseModel
from llama_stack.apis.providers import ListProvidersResponse, ProviderInfo, Providers
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import HealthResponse, HealthStatus
from .datatypes import StackRunConfig
from .stack import redact_sensitive_fields
from .utils.config import redact_sensitive_fields
logger = get_logger(name=__name__, category="core")
@ -41,19 +44,24 @@ class ProviderImpl(Providers):
async def list_providers(self) -> ListProvidersResponse:
run_config = self.config.run_config
safe_config = StackRunConfig(**redact_sensitive_fields(run_config.model_dump()))
providers_health = await self.get_providers_health()
ret = []
for api, providers in safe_config.providers.items():
ret.extend(
[
for p in providers:
ret.append(
ProviderInfo(
api=api,
provider_id=p.provider_id,
provider_type=p.provider_type,
config=p.config,
health=providers_health.get(api, {}).get(
p.provider_id,
HealthResponse(
status=HealthStatus.NOT_IMPLEMENTED, message="Provider does not implement health check"
),
),
)
for p in providers
]
)
)
return ListProvidersResponse(data=ret)
@ -64,3 +72,57 @@ class ProviderImpl(Providers):
return p
raise ValueError(f"Provider {provider_id} not found")
async def get_providers_health(self) -> Dict[str, Dict[str, HealthResponse]]:
"""Get health status for all providers.
Returns:
Dict[str, Dict[str, HealthResponse]]: A dictionary mapping API names to provider health statuses.
Each API maps to a dictionary of provider IDs to their health responses.
"""
providers_health: Dict[str, Dict[str, HealthResponse]] = {}
timeout = 1.0
async def check_provider_health(impl: Any) -> tuple[str, HealthResponse] | None:
# Skip special implementations (inspect/providers) that don't have provider specs
if not hasattr(impl, "__provider_spec__"):
return None
api_name = impl.__provider_spec__.api.name
if not hasattr(impl, "health"):
return (
api_name,
HealthResponse(
status=HealthStatus.NOT_IMPLEMENTED, message="Provider does not implement health check"
),
)
try:
health = await asyncio.wait_for(impl.health(), timeout=timeout)
return api_name, health
except asyncio.TimeoutError:
return (
api_name,
HealthResponse(
status=HealthStatus.ERROR, message=f"Health check timed out after {timeout} seconds"
),
)
except Exception as e:
return (
api_name,
HealthResponse(status=HealthStatus.ERROR, message=f"Health check failed: {str(e)}"),
)
# Create tasks for all providers
tasks = [check_provider_health(impl) for impl in self.deps.values()]
# Wait for all health checks to complete
results = await asyncio.gather(*tasks)
# Organize results by API and provider ID
for result in results:
if result is None: # Skip special implementations
continue
api_name, health_response = result
providers_health[api_name] = health_response
return providers_health

View file

@ -41,7 +41,6 @@ from llama_stack.providers.datatypes import (
Api,
BenchmarksProtocolPrivate,
DatasetsProtocolPrivate,
InlineProviderSpec,
ModelsProtocolPrivate,
ProviderSpec,
RemoteProviderConfig,
@ -230,50 +229,9 @@ def sort_providers_by_deps(
{k: list(v.values()) for k, v in providers_with_specs.items()}
)
# Append built-in "inspect" provider
apis = [x[1].spec.api for x in sorted_providers]
sorted_providers.append(
(
"inspect",
ProviderWithSpec(
provider_id="__builtin__",
provider_type="__builtin__",
config={"run_config": run_config.model_dump()},
spec=InlineProviderSpec(
api=Api.inspect,
provider_type="__builtin__",
config_class="llama_stack.distribution.inspect.DistributionInspectConfig",
module="llama_stack.distribution.inspect",
api_dependencies=apis,
deps__=[x.value for x in apis],
),
),
)
)
sorted_providers.append(
(
"providers",
ProviderWithSpec(
provider_id="__builtin__",
provider_type="__builtin__",
config={"run_config": run_config.model_dump()},
spec=InlineProviderSpec(
api=Api.providers,
provider_type="__builtin__",
config_class="llama_stack.distribution.providers.ProviderImplConfig",
module="llama_stack.distribution.providers",
api_dependencies=apis,
deps__=[x.value for x in apis],
),
),
)
)
logger.debug(f"Resolved {len(sorted_providers)} providers")
for api_str, provider in sorted_providers:
logger.debug(f" {api_str} => {provider.provider_id}")
logger.debug("")
return sorted_providers
@ -400,6 +358,8 @@ def check_protocol_compliance(obj: Any, protocol: Any) -> None:
mro = type(obj).__mro__
for name, value in inspect.getmembers(protocol):
if inspect.isfunction(value) and hasattr(value, "__webmethod__"):
if value.__webmethod__.experimental:
continue
if not hasattr(obj, name):
missing_methods.append((name, "missing"))
elif not callable(getattr(obj, name)):

View file

@ -4,6 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import asyncio
import time
from typing import Any, AsyncGenerator, AsyncIterator, Dict, List, Optional, Union
@ -17,6 +18,8 @@ from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.datasets import DatasetPurpose, DataSource
from llama_stack.apis.eval import BenchmarkConfig, Eval, EvaluateResponse, Job
from llama_stack.apis.inference import (
BatchChatCompletionResponse,
BatchCompletionResponse,
ChatCompletionResponse,
ChatCompletionResponseEventType,
ChatCompletionResponseStreamChunk,
@ -35,6 +38,13 @@ from llama_stack.apis.inference import (
ToolDefinition,
ToolPromptFormat,
)
from llama_stack.apis.inference.inference import (
OpenAIChatCompletion,
OpenAIChatCompletionChunk,
OpenAICompletion,
OpenAIMessageParam,
OpenAIResponseFormatParam,
)
from llama_stack.apis.models import Model, ModelType
from llama_stack.apis.safety import RunShieldResponse, Safety
from llama_stack.apis.scoring import (
@ -57,7 +67,7 @@ from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
from llama_stack.log import get_logger
from llama_stack.models.llama.llama3.chat_format import ChatFormat
from llama_stack.models.llama.llama3.tokenizer import Tokenizer
from llama_stack.providers.datatypes import RoutingTable
from llama_stack.providers.datatypes import HealthResponse, HealthStatus, RoutingTable
from llama_stack.providers.utils.telemetry.tracing import get_current_span
logger = get_logger(name=__name__, category="core")
@ -333,6 +343,30 @@ class InferenceRouter(Inference):
response.metrics = metrics if response.metrics is None else response.metrics + metrics
return response
async def batch_chat_completion(
self,
model_id: str,
messages_batch: List[List[Message]],
tools: Optional[List[ToolDefinition]] = None,
tool_config: Optional[ToolConfig] = None,
sampling_params: Optional[SamplingParams] = None,
response_format: Optional[ResponseFormat] = None,
logprobs: Optional[LogProbConfig] = None,
) -> BatchChatCompletionResponse:
logger.debug(
f"InferenceRouter.batch_chat_completion: {model_id=}, {len(messages_batch)=}, {sampling_params=}, {response_format=}, {logprobs=}",
)
provider = self.routing_table.get_provider_impl(model_id)
return await provider.batch_chat_completion(
model_id=model_id,
messages_batch=messages_batch,
tools=tools,
tool_config=tool_config,
sampling_params=sampling_params,
response_format=response_format,
logprobs=logprobs,
)
async def completion(
self,
model_id: str,
@ -397,6 +431,20 @@ class InferenceRouter(Inference):
response.metrics = metrics if response.metrics is None else response.metrics + metrics
return response
async def batch_completion(
self,
model_id: str,
content_batch: List[InterleavedContent],
sampling_params: Optional[SamplingParams] = None,
response_format: Optional[ResponseFormat] = None,
logprobs: Optional[LogProbConfig] = None,
) -> BatchCompletionResponse:
logger.debug(
f"InferenceRouter.batch_completion: {model_id=}, {len(content_batch)=}, {sampling_params=}, {response_format=}, {logprobs=}",
)
provider = self.routing_table.get_provider_impl(model_id)
return await provider.batch_completion(model_id, content_batch, sampling_params, response_format, logprobs)
async def embeddings(
self,
model_id: str,
@ -419,6 +467,149 @@ class InferenceRouter(Inference):
task_type=task_type,
)
async def openai_completion(
self,
model: str,
prompt: Union[str, List[str], List[int], List[List[int]]],
best_of: Optional[int] = None,
echo: Optional[bool] = None,
frequency_penalty: Optional[float] = None,
logit_bias: Optional[Dict[str, float]] = None,
logprobs: Optional[bool] = None,
max_tokens: Optional[int] = None,
n: Optional[int] = None,
presence_penalty: Optional[float] = None,
seed: Optional[int] = None,
stop: Optional[Union[str, List[str]]] = None,
stream: Optional[bool] = None,
stream_options: Optional[Dict[str, Any]] = None,
temperature: Optional[float] = None,
top_p: Optional[float] = None,
user: Optional[str] = None,
guided_choice: Optional[List[str]] = None,
prompt_logprobs: Optional[int] = None,
) -> OpenAICompletion:
logger.debug(
f"InferenceRouter.openai_completion: {model=}, {stream=}, {prompt=}",
)
model_obj = await self.routing_table.get_model(model)
if model_obj is None:
raise ValueError(f"Model '{model}' not found")
if model_obj.model_type == ModelType.embedding:
raise ValueError(f"Model '{model}' is an embedding model and does not support completions")
params = dict(
model=model_obj.identifier,
prompt=prompt,
best_of=best_of,
echo=echo,
frequency_penalty=frequency_penalty,
logit_bias=logit_bias,
logprobs=logprobs,
max_tokens=max_tokens,
n=n,
presence_penalty=presence_penalty,
seed=seed,
stop=stop,
stream=stream,
stream_options=stream_options,
temperature=temperature,
top_p=top_p,
user=user,
guided_choice=guided_choice,
prompt_logprobs=prompt_logprobs,
)
provider = self.routing_table.get_provider_impl(model_obj.identifier)
return await provider.openai_completion(**params)
async def openai_chat_completion(
self,
model: str,
messages: List[OpenAIMessageParam],
frequency_penalty: Optional[float] = None,
function_call: Optional[Union[str, Dict[str, Any]]] = None,
functions: Optional[List[Dict[str, Any]]] = None,
logit_bias: Optional[Dict[str, float]] = None,
logprobs: Optional[bool] = None,
max_completion_tokens: Optional[int] = None,
max_tokens: Optional[int] = None,
n: Optional[int] = None,
parallel_tool_calls: Optional[bool] = None,
presence_penalty: Optional[float] = None,
response_format: Optional[OpenAIResponseFormatParam] = None,
seed: Optional[int] = None,
stop: Optional[Union[str, List[str]]] = None,
stream: Optional[bool] = None,
stream_options: Optional[Dict[str, Any]] = None,
temperature: Optional[float] = None,
tool_choice: Optional[Union[str, Dict[str, Any]]] = None,
tools: Optional[List[Dict[str, Any]]] = None,
top_logprobs: Optional[int] = None,
top_p: Optional[float] = None,
user: Optional[str] = None,
) -> Union[OpenAIChatCompletion, AsyncIterator[OpenAIChatCompletionChunk]]:
logger.debug(
f"InferenceRouter.openai_chat_completion: {model=}, {stream=}, {messages=}",
)
model_obj = await self.routing_table.get_model(model)
if model_obj is None:
raise ValueError(f"Model '{model}' not found")
if model_obj.model_type == ModelType.embedding:
raise ValueError(f"Model '{model}' is an embedding model and does not support chat completions")
params = dict(
model=model_obj.identifier,
messages=messages,
frequency_penalty=frequency_penalty,
function_call=function_call,
functions=functions,
logit_bias=logit_bias,
logprobs=logprobs,
max_completion_tokens=max_completion_tokens,
max_tokens=max_tokens,
n=n,
parallel_tool_calls=parallel_tool_calls,
presence_penalty=presence_penalty,
response_format=response_format,
seed=seed,
stop=stop,
stream=stream,
stream_options=stream_options,
temperature=temperature,
tool_choice=tool_choice,
tools=tools,
top_logprobs=top_logprobs,
top_p=top_p,
user=user,
)
provider = self.routing_table.get_provider_impl(model_obj.identifier)
return await provider.openai_chat_completion(**params)
async def health(self) -> Dict[str, HealthResponse]:
health_statuses = {}
timeout = 0.5
for provider_id, impl in self.routing_table.impls_by_provider_id.items():
try:
# check if the provider has a health method
if not hasattr(impl, "health"):
continue
health = await asyncio.wait_for(impl.health(), timeout=timeout)
health_statuses[provider_id] = health
except asyncio.TimeoutError:
health_statuses[provider_id] = HealthResponse(
status=HealthStatus.ERROR,
message=f"Health check timed out after {timeout} seconds",
)
except NotImplementedError:
health_statuses[provider_id] = HealthResponse(status=HealthStatus.NOT_IMPLEMENTED)
except Exception as e:
health_statuses[provider_id] = HealthResponse(
status=HealthStatus.ERROR, message=f"Health check failed: {str(e)}"
)
return health_statuses
class SafetyRouter(Safety):
def __init__(

View file

@ -5,6 +5,7 @@
# the root directory of this source tree.
import logging
import time
import uuid
from typing import Any, Dict, List, Optional
@ -23,7 +24,7 @@ from llama_stack.apis.datasets import (
RowsDataSource,
URIDataSource,
)
from llama_stack.apis.models import ListModelsResponse, Model, Models, ModelType
from llama_stack.apis.models import ListModelsResponse, Model, Models, ModelType, OpenAIListModelsResponse, OpenAIModel
from llama_stack.apis.resource import ResourceType
from llama_stack.apis.scoring_functions import (
ListScoringFunctionsResponse,
@ -254,6 +255,19 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
async def list_models(self) -> ListModelsResponse:
return ListModelsResponse(data=await self.get_all_with_type("model"))
async def openai_list_models(self) -> OpenAIListModelsResponse:
models = await self.get_all_with_type("model")
openai_models = [
OpenAIModel(
id=model.identifier,
object="model",
created=int(time.time()),
owned_by="llama_stack",
)
for model in models
]
return OpenAIListModelsResponse(data=openai_models)
async def get_model(self, model_id: str) -> Model:
model = await self.get_object_by_identifier("model", model_id)
if model is None:

View file

@ -38,10 +38,10 @@ from llama_stack.distribution.server.endpoints import (
)
from llama_stack.distribution.stack import (
construct_stack,
redact_sensitive_fields,
replace_env_vars,
validate_env_pair,
)
from llama_stack.distribution.utils.config import redact_sensitive_fields
from llama_stack.distribution.utils.context import preserve_contexts_async_generator
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import Api
@ -92,7 +92,7 @@ async def global_exception_handler(request: Request, exc: Exception):
def translate_exception(exc: Exception) -> Union[HTTPException, RequestValidationError]:
if isinstance(exc, ValidationError):
exc = RequestValidationError(exc.raw_errors)
exc = RequestValidationError(exc.errors())
if isinstance(exc, RequestValidationError):
return HTTPException(
@ -162,9 +162,10 @@ async def maybe_await(value):
return value
async def sse_generator(event_gen):
async def sse_generator(event_gen_coroutine):
event_gen = await event_gen_coroutine
try:
async for item in await event_gen:
async for item in event_gen:
yield create_sse_event(item)
await asyncio.sleep(0.01)
except asyncio.CancelledError:
@ -229,15 +230,30 @@ class TracingMiddleware:
def __init__(self, app, impls):
self.app = app
self.impls = impls
# FastAPI built-in paths that should bypass custom routing
self.fastapi_paths = ("/docs", "/redoc", "/openapi.json", "/favicon.ico", "/static")
async def __call__(self, scope, receive, send):
if scope.get("type") == "lifespan":
return await self.app(scope, receive, send)
path = scope.get("path", "")
# Check if the path is a FastAPI built-in path
if path.startswith(self.fastapi_paths):
# Pass through to FastAPI's built-in handlers
logger.debug(f"Bypassing custom routing for FastAPI built-in path: {path}")
return await self.app(scope, receive, send)
if not hasattr(self, "endpoint_impls"):
self.endpoint_impls = initialize_endpoint_impls(self.impls)
_, _, trace_path = find_matching_endpoint(scope.get("method", "GET"), path, self.endpoint_impls)
try:
_, _, trace_path = find_matching_endpoint(scope.get("method", "GET"), path, self.endpoint_impls)
except ValueError:
# If no matching endpoint is found, pass through to FastAPI
logger.debug(f"No matching endpoint found for path: {path}, falling back to FastAPI")
return await self.app(scope, receive, send)
trace_context = await start_trace(trace_path, {"__location__": "server", "raw_path": path})
@ -388,7 +404,12 @@ def main(args: Optional[argparse.Namespace] = None):
safe_config = redact_sensitive_fields(config.model_dump())
logger.info(yaml.dump(safe_config, indent=2))
app = FastAPI(lifespan=lifespan)
app = FastAPI(
lifespan=lifespan,
docs_url="/docs",
redoc_url="/redoc",
openapi_url="/openapi.json",
)
if not os.environ.get("LLAMA_STACK_DISABLE_VERSION_CHECK"):
app.add_middleware(ClientVersionMiddleware)

View file

@ -35,6 +35,8 @@ from llama_stack.apis.vector_dbs import VectorDBs
from llama_stack.apis.vector_io import VectorIO
from llama_stack.distribution.datatypes import Provider, StackRunConfig
from llama_stack.distribution.distribution import get_provider_registry
from llama_stack.distribution.inspect import DistributionInspectConfig, DistributionInspectImpl
from llama_stack.distribution.providers import ProviderImpl, ProviderImplConfig
from llama_stack.distribution.resolver import ProviderRegistry, resolve_impls
from llama_stack.distribution.store.registry import create_dist_registry
from llama_stack.distribution.utils.dynamic import instantiate_class_type
@ -96,7 +98,10 @@ async def register_resources(run_config: StackRunConfig, impls: Dict[Api, Any]):
method = getattr(impls[api], register_method)
for obj in objects:
await method(**obj.model_dump())
# we want to maintain the type information in arguments to method.
# instead of method(**obj.model_dump()), which may convert a typed attr to a dict,
# we use model_dump() to find all the attrs and then getattr to get the still typed value.
await method(**{k: getattr(obj, k) for k in obj.model_dump().keys()})
method = getattr(impls[api], list_method)
response = await method()
@ -116,26 +121,6 @@ class EnvVarError(Exception):
super().__init__(f"Environment variable '{var_name}' not set or empty{f' at {path}' if path else ''}")
def redact_sensitive_fields(data: Dict[str, Any]) -> Dict[str, Any]:
"""Redact sensitive information from config before printing."""
sensitive_patterns = ["api_key", "api_token", "password", "secret"]
def _redact_dict(d: Dict[str, Any]) -> Dict[str, Any]:
result = {}
for k, v in d.items():
if isinstance(v, dict):
result[k] = _redact_dict(v)
elif isinstance(v, list):
result[k] = [_redact_dict(i) if isinstance(i, dict) else i for i in v]
elif any(pattern in k.lower() for pattern in sensitive_patterns):
result[k] = "********"
else:
result[k] = v
return result
return _redact_dict(data)
def replace_env_vars(config: Any, path: str = "") -> Any:
if isinstance(config, dict):
result = {}
@ -212,6 +197,26 @@ def validate_env_pair(env_pair: str) -> tuple[str, str]:
) from e
def add_internal_implementations(impls: Dict[Api, Any], run_config: StackRunConfig) -> None:
"""Add internal implementations (inspect and providers) to the implementations dictionary.
Args:
impls: Dictionary of API implementations
run_config: Stack run configuration
"""
inspect_impl = DistributionInspectImpl(
DistributionInspectConfig(run_config=run_config),
deps=impls,
)
impls[Api.inspect] = inspect_impl
providers_impl = ProviderImpl(
ProviderImplConfig(run_config=run_config),
deps=impls,
)
impls[Api.providers] = providers_impl
# Produces a stack of providers for the given run config. Not all APIs may be
# asked for in the run config.
async def construct_stack(
@ -219,6 +224,10 @@ async def construct_stack(
) -> Dict[Api, Any]:
dist_registry, _ = await create_dist_registry(run_config.metadata_store, run_config.image_name)
impls = await resolve_impls(run_config, provider_registry or get_provider_registry(run_config), dist_registry)
# Add internal implementations after all other providers are resolved
add_internal_implementations(impls, run_config)
await register_resources(run_config, impls)
return impls

View file

@ -18,6 +18,7 @@ VIRTUAL_ENV=${VIRTUAL_ENV:-}
set -euo pipefail
RED='\033[0;31m'
GREEN='\033[0;32m'
NC='\033[0m' # No Color
error_handler() {
@ -73,7 +74,7 @@ done
PYTHON_BINARY="python"
case "$env_type" in
"venv")
if [ -n "$VIRTUAL_ENV" && "$VIRTUAL_ENV" == "$env_path_or_name" ]; then
if [ -n "$VIRTUAL_ENV" ] && [ "$VIRTUAL_ENV" == "$env_path_or_name" ]; then
echo -e "${GREEN}Virtual environment already activated${NC}" >&2
else
# Activate virtual environment

View file

@ -36,9 +36,7 @@ llama-stack-client benchmarks register \
3. Start Streamlit UI
```bash
cd llama_stack/distribution/ui
pip install -r requirements.txt
streamlit run app.py
uv run --with ".[ui]" streamlit run llama_stack/distribution/ui/app.py
```
## Environment Variables

View file

@ -9,6 +9,7 @@ import uuid
import streamlit as st
from llama_stack_client import Agent, AgentEventLogger, RAGDocument
from llama_stack.apis.common.content_types import ToolCallDelta
from llama_stack.distribution.ui.modules.api import llama_stack_api
from llama_stack.distribution.ui.modules.utils import data_url_from_file
@ -16,9 +17,23 @@ from llama_stack.distribution.ui.modules.utils import data_url_from_file
def rag_chat_page():
st.title("🦙 RAG")
def reset_agent_and_chat():
st.session_state.clear()
st.cache_resource.clear()
def should_disable_input():
return "displayed_messages" in st.session_state and len(st.session_state.displayed_messages) > 0
def log_message(message):
with st.chat_message(message["role"]):
if "tool_output" in message and message["tool_output"]:
with st.expander(label="Tool Output", expanded=False, icon="🛠"):
st.write(message["tool_output"])
st.markdown(message["content"])
with st.sidebar:
# File/Directory Upload Section
st.subheader("Upload Documents")
st.subheader("Upload Documents", divider=True)
uploaded_files = st.file_uploader(
"Upload file(s) or directory",
accept_multiple_files=True,
@ -29,11 +44,11 @@ def rag_chat_page():
st.success(f"Successfully uploaded {len(uploaded_files)} files")
# Add memory bank name input field
vector_db_name = st.text_input(
"Vector Database Name",
"Document Collection Name",
value="rag_vector_db",
help="Enter a unique identifier for this vector database",
help="Enter a unique identifier for this document collection",
)
if st.button("Create Vector Database"):
if st.button("Create Document Collection"):
documents = [
RAGDocument(
document_id=uploaded_file.name,
@ -64,26 +79,45 @@ def rag_chat_page():
)
st.success("Vector database created successfully!")
st.subheader("Configure Agent")
st.subheader("RAG Parameters", divider=True)
rag_mode = st.radio(
"RAG mode",
["Direct", "Agent-based"],
captions=[
"RAG is performed by directly retrieving the information and augmenting the user query",
"RAG is performed by an agent activating a dedicated knowledge search tool.",
],
on_change=reset_agent_and_chat,
disabled=should_disable_input(),
)
# select memory banks
vector_dbs = llama_stack_api.client.vector_dbs.list()
vector_dbs = [vector_db.identifier for vector_db in vector_dbs]
selected_vector_dbs = st.multiselect(
"Select Vector Databases",
vector_dbs,
label="Select Document Collections to use in RAG queries",
options=vector_dbs,
on_change=reset_agent_and_chat,
disabled=should_disable_input(),
)
st.subheader("Inference Parameters", divider=True)
available_models = llama_stack_api.client.models.list()
available_models = [model.identifier for model in available_models if model.model_type == "llm"]
selected_model = st.selectbox(
"Choose a model",
available_models,
label="Choose a model",
options=available_models,
index=0,
on_change=reset_agent_and_chat,
disabled=should_disable_input(),
)
system_prompt = st.text_area(
"System Prompt",
value="You are a helpful assistant. ",
help="Initial instructions given to the AI to set its behavior and context",
on_change=reset_agent_and_chat,
disabled=should_disable_input(),
)
temperature = st.slider(
"Temperature",
@ -92,6 +126,8 @@ def rag_chat_page():
value=0.0,
step=0.1,
help="Controls the randomness of the response. Higher values make the output more creative and unexpected, lower values make it more conservative and predictable",
on_change=reset_agent_and_chat,
disabled=should_disable_input(),
)
top_p = st.slider(
@ -100,21 +136,24 @@ def rag_chat_page():
max_value=1.0,
value=0.95,
step=0.1,
on_change=reset_agent_and_chat,
disabled=should_disable_input(),
)
# Add clear chat button to sidebar
if st.button("Clear Chat", use_container_width=True):
st.session_state.clear()
st.cache_resource.clear()
reset_agent_and_chat()
st.rerun()
# Chat Interface
if "messages" not in st.session_state:
st.session_state.messages = []
if "displayed_messages" not in st.session_state:
st.session_state.displayed_messages = []
# Display chat history
for message in st.session_state.messages:
with st.chat_message(message["role"]):
st.markdown(message["content"])
for message in st.session_state.displayed_messages:
log_message(message)
if temperature > 0.0:
strategy = {
@ -144,22 +183,18 @@ def rag_chat_page():
],
)
agent = create_agent()
if rag_mode == "Agent-based":
agent = create_agent()
if "agent_session_id" not in st.session_state:
st.session_state["agent_session_id"] = agent.create_session(session_name=f"rag_demo_{uuid.uuid4()}")
if "agent_session_id" not in st.session_state:
st.session_state["agent_session_id"] = agent.create_session(session_name=f"rag_demo_{uuid.uuid4()}")
session_id = st.session_state["agent_session_id"]
session_id = st.session_state["agent_session_id"]
# Chat input
if prompt := st.chat_input("Ask a question about your documents"):
def agent_process_prompt(prompt):
# Add user message to chat history
st.session_state.messages.append({"role": "user", "content": prompt})
# Display user message
with st.chat_message("user"):
st.markdown(prompt)
# Send the prompt to the agent
response = agent.create_turn(
messages=[
{
@ -172,7 +207,7 @@ def rag_chat_page():
# Display assistant response
with st.chat_message("assistant"):
retrieval_message_placeholder = st.empty()
retrieval_message_placeholder = st.expander(label="Tool Output", expanded=False, icon="🛠")
message_placeholder = st.empty()
full_response = ""
retrieval_response = ""
@ -180,13 +215,87 @@ def rag_chat_page():
log.print()
if log.role == "tool_execution":
retrieval_response += log.content.replace("====", "").strip()
retrieval_message_placeholder.info(retrieval_response)
retrieval_message_placeholder.write(retrieval_response)
else:
full_response += log.content
message_placeholder.markdown(full_response + "")
message_placeholder.markdown(full_response)
st.session_state.messages.append({"role": "assistant", "content": full_response})
st.session_state.displayed_messages.append(
{"role": "assistant", "content": full_response, "tool_output": retrieval_response}
)
def direct_process_prompt(prompt):
# Add the system prompt in the beginning of the conversation
if len(st.session_state.messages) == 0:
st.session_state.messages.append({"role": "system", "content": system_prompt})
# Query the vector DB
rag_response = llama_stack_api.client.tool_runtime.rag_tool.query(
content=prompt, vector_db_ids=list(selected_vector_dbs)
)
prompt_context = rag_response.content
with st.chat_message("assistant"):
with st.expander(label="Retrieval Output", expanded=False):
st.write(prompt_context)
retrieval_message_placeholder = st.empty()
message_placeholder = st.empty()
full_response = ""
retrieval_response = ""
# Construct the extended prompt
extended_prompt = f"Please answer the following query using the context below.\n\nCONTEXT:\n{prompt_context}\n\nQUERY:\n{prompt}"
# Run inference directly
st.session_state.messages.append({"role": "user", "content": extended_prompt})
response = llama_stack_api.client.inference.chat_completion(
messages=st.session_state.messages,
model_id=selected_model,
sampling_params={
"strategy": strategy,
},
stream=True,
)
# Display assistant response
for chunk in response:
response_delta = chunk.event.delta
if isinstance(response_delta, ToolCallDelta):
retrieval_response += response_delta.tool_call.replace("====", "").strip()
retrieval_message_placeholder.info(retrieval_response)
else:
full_response += chunk.event.delta.text
message_placeholder.markdown(full_response + "")
message_placeholder.markdown(full_response)
response_dict = {"role": "assistant", "content": full_response, "stop_reason": "end_of_message"}
st.session_state.messages.append(response_dict)
st.session_state.displayed_messages.append(response_dict)
# Chat input
if prompt := st.chat_input("Ask a question about your documents"):
# Add user message to chat history
st.session_state.displayed_messages.append({"role": "user", "content": prompt})
# Display user message
with st.chat_message("user"):
st.markdown(prompt)
# store the prompt to process it after page refresh
st.session_state.prompt = prompt
# force page refresh to disable the settings widgets
st.rerun()
if "prompt" in st.session_state and st.session_state.prompt is not None:
if rag_mode == "Agent-based":
agent_process_prompt(st.session_state.prompt)
else: # rag_mode == "Direct"
direct_process_prompt(st.session_state.prompt)
st.session_state.prompt = None
rag_chat_page()

View file

@ -29,17 +29,39 @@ def tool_chat_page():
st.cache_resource.clear()
with st.sidebar:
st.title("Configuration")
st.subheader("Model")
model = st.selectbox(label="models", options=model_list, on_change=reset_agent)
model = st.selectbox(label="Model", options=model_list, on_change=reset_agent, label_visibility="collapsed")
st.subheader("Available ToolGroups")
st.subheader("Builtin Tools")
toolgroup_selection = st.pills(
label="Available ToolGroups", options=builtin_tools_list, selection_mode="multi", on_change=reset_agent
label="Built-in tools",
options=builtin_tools_list,
selection_mode="multi",
on_change=reset_agent,
format_func=lambda tool: "".join(tool.split("::")[1:]),
help="List of built-in tools from your llama stack server.",
)
st.subheader("MCP Servers")
if "builtin::rag" in toolgroup_selection:
vector_dbs = llama_stack_api.client.vector_dbs.list() or []
if not vector_dbs:
st.info("No vector databases available for selection.")
vector_dbs = [vector_db.identifier for vector_db in vector_dbs]
selected_vector_dbs = st.multiselect(
label="Select Document Collections to use in RAG queries",
options=vector_dbs,
on_change=reset_agent,
)
mcp_selection = st.pills(
label="Available MCP Servers", options=mcp_tools_list, selection_mode="multi", on_change=reset_agent
label="MCP Servers",
options=mcp_tools_list,
selection_mode="multi",
on_change=reset_agent,
format_func=lambda tool: "".join(tool.split("::")[1:]),
help="List of MCP servers registered to your llama stack server.",
)
toolgroup_selection.extend(mcp_selection)
@ -53,9 +75,30 @@ def tool_chat_page():
]
)
st.subheader(f"Active Tools: 🛠 {len(active_tool_list)}")
st.markdown(f"Active Tools: 🛠 {len(active_tool_list)}", help="List of currently active tools.")
st.json(active_tool_list)
st.subheader("Agent Configurations")
max_tokens = st.slider(
"Max Tokens",
min_value=0,
max_value=4096,
value=512,
step=1,
help="The maximum number of tokens to generate",
on_change=reset_agent,
)
for i, tool_name in enumerate(toolgroup_selection):
if tool_name == "builtin::rag":
tool_dict = dict(
name="builtin::rag",
args={
"vector_db_ids": list(selected_vector_dbs),
},
)
toolgroup_selection[i] = tool_dict
@st.cache_resource
def create_agent():
return Agent(
@ -63,9 +106,7 @@ def tool_chat_page():
model=model,
instructions="You are a helpful assistant. When you use a tool always respond with a summary of the result.",
tools=toolgroup_selection,
sampling_params={
"strategy": {"type": "greedy"},
},
sampling_params={"strategy": {"type": "greedy"}, "max_tokens": max_tokens},
)
agent = create_agent()
@ -103,7 +144,11 @@ def tool_chat_page():
yield response.event.payload.delta.text
if response.event.payload.event_type == "step_complete":
if response.event.payload.step_details.step_type == "tool_execution":
yield " 🛠 "
if response.event.payload.step_details.tool_calls:
tool_name = str(response.event.payload.step_details.tool_calls[0].tool_name)
yield f'\n\n🛠 :grey[_Using "{tool_name}" tool:_]\n\n'
else:
yield "No tool_calls present in step_details"
else:
yield f"Error occurred in the Llama Stack Cluster: {response}"

View file

@ -1,5 +1,5 @@
streamlit
pandas
llama-stack-client>=0.0.55
llama-stack-client>=0.2.1
streamlit-option-menu
llama-stack>=0.1.9
llama-stack>=0.2.1

View file

@ -0,0 +1,30 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict
def redact_sensitive_fields(data: Dict[str, Any]) -> Dict[str, Any]:
"""Redact sensitive information from config before printing."""
sensitive_patterns = ["api_key", "api_token", "password", "secret"]
def _redact_value(v: Any) -> Any:
if isinstance(v, dict):
return _redact_dict(v)
elif isinstance(v, list):
return [_redact_value(i) for i in v]
return v
def _redact_dict(d: Dict[str, Any]) -> Dict[str, Any]:
result = {}
for k, v in d.items():
if any(pattern in k.lower() for pattern in sensitive_patterns):
result[k] = "********"
else:
result[k] = _redact_value(v)
return result
return _redact_dict(data)

View file

@ -226,7 +226,6 @@ class ChatFormat:
arguments_json=json.dumps(tool_arguments),
)
)
content = ""
return RawMessage(
role="assistant",

View file

@ -140,7 +140,12 @@ class Llama3:
return Llama3(model, tokenizer, model_args)
def __init__(self, model: Transformer | CrossAttentionTransformer, tokenizer: Tokenizer, args: ModelArgs):
def __init__(
self,
model: Transformer | CrossAttentionTransformer,
tokenizer: Tokenizer,
args: ModelArgs,
):
self.args = args
self.model = model
self.tokenizer = tokenizer
@ -149,7 +154,7 @@ class Llama3:
@torch.inference_mode()
def generate(
self,
model_inputs: List[LLMInput],
llm_inputs: List[LLMInput],
temperature: float = 0.6,
top_p: float = 0.9,
max_gen_len: Optional[int] = None,
@ -164,15 +169,15 @@ class Llama3:
print_model_input = print_model_input or os.environ.get("LLAMA_MODELS_DEBUG", "0") == "1"
if print_model_input:
for inp in model_inputs:
for inp in llm_inputs:
tokens_to_print = [self.formatter.vision_token if t == 128256 else t for t in inp.tokens]
cprint(
"Input to model:\n" + self.tokenizer.decode(tokens_to_print) + "\n",
"red",
)
prompt_tokens = [inp.tokens for inp in model_inputs]
prompt_tokens = [inp.tokens for inp in llm_inputs]
bsz = len(model_inputs)
bsz = len(llm_inputs)
assert bsz <= params.max_batch_size, (bsz, params.max_batch_size)
min_prompt_len = min(len(t) for t in prompt_tokens)
@ -193,8 +198,8 @@ class Llama3:
is_vision = not isinstance(self.model, Transformer)
if is_vision:
images = [inp.vision.images if inp.vision is not None else [] for inp in model_inputs]
mask = [inp.vision.mask if inp.vision is not None else [] for inp in model_inputs]
images = [inp.vision.images if inp.vision is not None else [] for inp in llm_inputs]
mask = [inp.vision.mask if inp.vision is not None else [] for inp in llm_inputs]
xattn_caches, cross_attention_masks, full_text_row_masked_out_mask = self.model.compute_vision_tokens_masks(
batch_images=images,
@ -229,7 +234,7 @@ class Llama3:
for cur_pos in range(min_prompt_len, total_len):
if is_vision:
position_ids = torch.arange(prev_pos, cur_pos, dtype=torch.long)
text_only_inference = all(inp.vision is None for inp in model_inputs)
text_only_inference = all(inp.vision is None for inp in llm_inputs)
logits = self.model.forward(
position_ids,
tokens,
@ -285,7 +290,7 @@ class Llama3:
source="output",
logprobs=(token_logprobs[idx, cur_pos : cur_pos + 1].tolist() if logprobs else None),
batch_idx=idx,
finished=eos_reached[idx],
finished=eos_reached[idx].item(),
ignore_token=cur_pos < len(prompt_tokens[idx]),
)
)

View file

@ -229,6 +229,11 @@ class PythonListCustomToolGenerator(PromptTemplateGeneratorBase): # noqa: N801
You are an expert in composing functions. You are given a question and a set of possible functions.
Based on the question, you may or may not need to make one function/tool call to achieve the purpose.
If you decide to invoke any of the function(s), you MUST put it in the format of [func_name1(params_name1=params_value1, params_name2=params_value2...), func_name2(params)]
If you decide to invoke a function, you SHOULD NOT include any other text in the response. besides the function call in the above format.
For a boolean parameter, be sure to use `True` or `False` (capitalized) for the value.
{{ function_description }}
""".strip("\n")
)
@ -243,10 +248,6 @@ class PythonListCustomToolGenerator(PromptTemplateGeneratorBase): # noqa: N801
def _gen_function_description(self, custom_tools: List[ToolDefinition]) -> PromptTemplate:
template_str = textwrap.dedent(
"""
If you decide to invoke any of the function(s), you MUST put it in the format of [func_name1(params_name1=params_value1, params_name2=params_value2...), func_name2(params)]
For a boolean parameter, be sure to use `True` or `False` (capitalized) for the value.
You SHOULD NOT include any other text in the response.
Here is a list of functions in JSON format that you can invoke.
[

View file

@ -4,13 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# top-level folder for each specific model found within the models/ directory at
# the top-level of this source tree.
import ast
import json
import re
from typing import Optional, Tuple
@ -35,80 +28,141 @@ def is_json(s):
return True
def is_valid_python_list(input_string):
"""Check if the input string is a valid Python list of function calls"""
try:
# Try to parse the string
tree = ast.parse(input_string)
# Check if it's a single expression
if len(tree.body) != 1 or not isinstance(tree.body[0], ast.Expr):
return False
# Check if the expression is a list
expr = tree.body[0].value
if not isinstance(expr, ast.List):
return False
# Check if the list is empty
if len(expr.elts) == 0:
return False
# Check if all elements in the list are function calls
for element in expr.elts:
if not isinstance(element, ast.Call):
return False
# Check if the function call has a valid name
if not isinstance(element.func, ast.Name):
return False
# Check if all arguments are keyword arguments
if element.args or not all(isinstance(arg, ast.keyword) for arg in element.keywords):
return False
return True
except SyntaxError:
# If parsing fails, it's not a valid Python expression
return False
def parse_python_list_for_function_calls(input_string):
def parse_llama_tool_call_format(input_string):
"""
Parse a Python list of function calls and
return a list of tuples containing the function name and arguments
"""
# Parse the string into an AST
tree = ast.parse(input_string)
Parse tool calls in the format:
[func_name1(params_name1=params_value1, params_name2=params_value2...), func_name2(params)]
# Ensure the input is a list
if not isinstance(tree.body[0], ast.Expr) or not isinstance(tree.body[0].value, ast.List):
raise ValueError("Input must be a list of function calls")
Returns a list of (function_name, arguments_dict) tuples or None if parsing fails.
"""
# Strip outer brackets and whitespace
input_string = input_string.strip()
if not (input_string.startswith("[") and input_string.endswith("]")):
return None
content = input_string[1:-1].strip()
if not content:
return None
result = []
# Iterate through each function call in the list
for node in tree.body[0].value.elts:
if isinstance(node, ast.Call):
function_name = node.func.id
function_args = {}
# State variables for parsing
pos = 0
length = len(content)
# Extract keyword arguments
for keyword in node.keywords:
try:
function_args[keyword.arg] = ast.literal_eval(keyword.value)
except ValueError as e:
logger.error(
f"Error parsing tool call argument '{keyword.arg}': {e}, full input string: '{input_string}'"
)
raise ValueError(
f"Error parsing tool call argument '{keyword.arg}', full input string: '{input_string}'"
) from e
while pos < length:
# Find function name
name_end = content.find("(", pos)
if name_end == -1:
break
result.append((function_name, function_args))
func_name = content[pos:name_end].strip()
return result
# Find closing parenthesis for this function call
paren_level = 1
args_start = name_end + 1
args_end = args_start
while args_end < length and paren_level > 0:
if content[args_end] == "(":
paren_level += 1
elif content[args_end] == ")":
paren_level -= 1
args_end += 1
if paren_level != 0:
# Unmatched parentheses
return None
# Parse arguments
args_str = content[args_start : args_end - 1].strip()
args_dict = {}
if args_str:
# Split by commas, but respect nested structures
parts = []
part_start = 0
in_quotes = False
quote_char = None
nested_level = 0
for i, char in enumerate(args_str):
if char in ('"', "'") and (i == 0 or args_str[i - 1] != "\\"):
if not in_quotes:
in_quotes = True
quote_char = char
elif char == quote_char:
in_quotes = False
quote_char = None
elif not in_quotes:
if char in ("{", "["):
nested_level += 1
elif char in ("}", "]"):
nested_level -= 1
elif char == "," and nested_level == 0:
parts.append(args_str[part_start:i].strip())
part_start = i + 1
parts.append(args_str[part_start:].strip())
# Process each key=value pair
for part in parts:
if "=" in part:
key, value = part.split("=", 1)
key = key.strip()
value = value.strip()
# Try to convert value to appropriate Python type
if (value.startswith('"') and value.endswith('"')) or (
value.startswith("'") and value.endswith("'")
):
# String
value = value[1:-1]
elif value.lower() == "true":
value = True
elif value.lower() == "false":
value = False
elif value.lower() == "none":
value = None
elif value.startswith("{") and value.endswith("}"):
# This is a nested dictionary
try:
# Try to parse as JSON
value = json.loads(value.replace("'", '"'))
except json.JSONDecodeError:
# Keep as string if parsing fails
pass
elif value.startswith("[") and value.endswith("]"):
# This is a nested list
try:
# Try to parse as JSON
value = json.loads(value.replace("'", '"'))
except json.JSONDecodeError:
# Keep as string if parsing fails
pass
else:
# Try to convert to number
try:
if "." in value:
value = float(value)
else:
value = int(value)
except ValueError:
# Keep as string if not a valid number
pass
args_dict[key] = value
result.append((func_name, args_dict))
# Move to the next function call
pos = args_end
# Skip the comma between function calls if present
if pos < length and content[pos] == ",":
pos += 1
return result if result else None
class ToolUtils:
@ -150,17 +204,19 @@ class ToolUtils:
return None
elif is_json(message_body):
response = json.loads(message_body)
if ("type" in response and response["type"] == "function") or ("name" in response):
if ("type" in response and response["type"] == "function") or (
"name" in response and "parameters" in response
):
function_name = response["name"]
args = response["parameters"]
return function_name, args
else:
return None
elif is_valid_python_list(message_body):
res = parse_python_list_for_function_calls(message_body)
elif function_calls := parse_llama_tool_call_format(message_body):
# FIXME: Enable multiple tool calls
return res[0]
return function_calls[0]
else:
logger.debug(f"Did not parse tool call from message body: {message_body}")
return None
@staticmethod

View file

@ -70,6 +70,9 @@ class ModelArgs(BaseModel):
attention_chunk_size: Optional[int] = None
rope_theta: float = 500000
use_scaled_rope: bool = False
rope_scaling_factor: Optional[float] = None
rope_high_freq_factor: Optional[float] = None
nope_layer_interval: Optional[int] = None # No position encoding in every n layers
use_qk_norm: bool = False
# Set to True to enable inference-time temperature tuning (useful for very long context)
@ -92,4 +95,14 @@ class ModelArgs(BaseModel):
f"n_heads ({self.n_heads}) must be divisible by n_kv_heads ({self.n_kv_heads})"
)
assert self.dim % self.n_heads == 0, f"dim ({self.dim}) must be divisible by n_heads ({self.n_heads})"
if self.use_scaled_rope:
# NOTE: ideally these values should have come from params.json. However, we have
# shipped the models everywhere. Only Llama-4-Scout uses scaled rope and needs these
# specific values.
if self.rope_scaling_factor is None:
self.rope_scaling_factor = 16
if self.rope_high_freq_factor is None:
self.rope_high_freq_factor = 1
return self

View file

@ -5,6 +5,7 @@
# the root directory of this source tree.
import io
import json
import uuid
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple
@ -299,9 +300,9 @@ class ChatFormat:
call_id=call_id,
tool_name=tool_name,
arguments=tool_arguments,
arguments_json=json.dumps(tool_arguments),
)
)
content = ""
return RawMessage(
role="assistant",

View file

@ -233,7 +233,7 @@ class Llama4:
source="output",
logprobs=(token_logprobs[idx, cur_pos : cur_pos + 1].tolist() if logprobs else None),
batch_idx=idx,
finished=eos_reached[idx],
finished=eos_reached[idx].item(),
ignore_token=cur_pos < len(prompt_tokens[idx]),
)
)

View file

@ -23,37 +23,25 @@ from .ffn import FeedForward
from .moe import MoE
def rmsnorm(x, eps):
def _norm(y):
return y * torch.rsqrt(y.pow(2).mean(-1, keepdim=True) + eps)
return _norm(x.float()).type_as(x)
class RMSNorm(torch.nn.Module):
def __init__(self, dim: int, eps: float = 1e-6):
super().__init__()
self.eps = eps
self.weight = nn.Parameter(torch.ones(dim))
def _norm(self, x):
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
def forward(self, x):
output = self._norm(x.float()).type_as(x)
return output * self.weight
return rmsnorm(x, self.eps) * self.weight
class L2Norm(torch.nn.Module):
def __init__(self, dim: int, eps: float = 1e-6):
super().__init__()
self.eps = eps
def _norm(self, x):
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
def forward(self, x):
return self._norm(x.float()).type_as(x)
def apply_scaling(freqs: torch.Tensor):
# Values obtained from grid search
scale_factor = 8
def apply_scaling(freqs: torch.Tensor, scale_factor: float, high_freq_factor: float):
low_freq_factor = 1
high_freq_factor = 4
old_context_len = 8192 # original llama3 length
low_freq_wavelen = old_context_len / low_freq_factor
@ -72,11 +60,18 @@ def apply_scaling(freqs: torch.Tensor):
return torch.tensor(new_freqs, dtype=freqs.dtype, device=freqs.device)
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0, use_scaled: bool = False):
def precompute_freqs_cis(
dim: int,
end: int,
theta: float,
use_scaled: bool,
scale_factor: float,
high_freq_factor: float,
):
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
t = torch.arange(end, device=freqs.device, dtype=torch.float32)
if use_scaled:
freqs = apply_scaling(freqs)
freqs = apply_scaling(freqs, scale_factor, high_freq_factor)
freqs = torch.outer(t, freqs)
freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64
return freqs_cis
@ -174,9 +169,7 @@ class Attention(nn.Module):
self.head_dim,
)
).cuda()
self.qk_norm = None
if self.use_qk_norm:
self.qk_norm = L2Norm(args.norm_eps)
self.norm_eps = args.norm_eps
self._register_load_state_dict_pre_hook(self.load_hook)
def load_hook(
@ -220,8 +213,8 @@ class Attention(nn.Module):
xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)
if self.use_qk_norm:
xq = self.qk_norm(xq)
xk = self.qk_norm(xk)
xq = rmsnorm(xq, self.norm_eps)
xk = rmsnorm(xk, self.norm_eps)
# We are applying temperature tuning (https://arxiv.org/abs/2501.19399) to NoPE layers, where
# the inference-time temperature tuning function is customized to not affect short context
@ -362,6 +355,8 @@ class Transformer(nn.Module):
args.max_seq_len * 2,
args.rope_theta,
args.use_scaled_rope,
args.rope_scaling_factor,
args.rope_high_freq_factor,
)
vision_args = self.args.vision_args
if vision_args:

View file

@ -91,7 +91,7 @@ def convert_to_quantized_model(
log_status(f"Rank {rank}: Quantizing int4 weights from bf16")
def apply_quantization(_, weight):
return quantize_int4(weight, fp8_activation_scale_ub, output_device=torch.device("cuda"))
return quantize_int4(weight, output_device=torch.device("cuda"))
else:
fp8_scales_path = os.path.join(checkpoint_dir, f"fp8_scales_{rank}.pt")

View file

@ -56,8 +56,8 @@ LLAMA4_TEXT_POST_TRAIN_SPECIAL_TOKENS = [
"<|text_post_train_reserved_special_token_3|>",
"<|text_post_train_reserved_special_token_4|>",
"<|text_post_train_reserved_special_token_5|>",
"<|text_post_train_reserved_special_token_6|>",
"<|text_post_train_reserved_special_token_7|>",
"<|python_start|>",
"<|python_end|>",
"<|finetune_right_pad|>",
] + get_reserved_special_tokens(
"text_post_train", 61, 8

View file

@ -65,7 +65,7 @@ class Int4Weights(
Int4ScaledWeights,
collections.namedtuple(
"Int4Weights",
["weight", "scale", "zero_point", "shape", "activation_scale_ub"],
["weight", "scale", "zero_point", "shape"],
),
):
pass
@ -184,20 +184,13 @@ def quantize_fp8(
@torch.inference_mode()
def quantize_int4(
w: Tensor,
fp8_activation_scale_ub: float,
output_device: Optional[torch.device] = None,
) -> Int4Weights:
"""Quantize [n, k/2] weight tensor.
Args:
w (Tensor): [n, k/2] input high precision tensor to quantize.
fp8_activation_scale_ub (float): Upper bound for activation max.
"""
activation_scale_ub = torch.tensor(
[fp8_activation_scale_ub],
dtype=torch.float,
device=output_device,
)
if w.ndim >= 3:
wq, scale, zero_point = zip(*[int4_row_quantize(i) for i in w], strict=False)
wq = torch.stack([pack_int4(i) for i in wq], dim=0)
@ -212,7 +205,6 @@ def quantize_int4(
scale=scale.to(output_device),
zero_point=zero_point.to(output_device),
shape=wq.shape,
activation_scale_ub=activation_scale_ub,
)
@ -247,26 +239,18 @@ def load_int4(
w: Tensor,
scale: Tensor,
zero_point: Tensor,
fp8_activation_scale_ub: float,
output_device: Optional[torch.device] = None,
) -> Int4Weights:
"""Load INT4 [n, k/2] weight tensor.
Args:
w (Tensor): [n, k/2] input INT4.
fp8_activation_scale_ub (float): Upper bound for activation max.
"""
activation_scale_ub = torch.tensor(
[fp8_activation_scale_ub],
dtype=torch.float,
device=output_device,
)
return Int4Weights(
weight=w.to(torch.int8).to(device=output_device),
scale=scale.to(device=output_device),
zero_point=zero_point.to(device=output_device),
shape=w.shape,
activation_scale_ub=activation_scale_ub,
)

View file

@ -4,6 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from enum import Enum
from typing import Any, List, Optional, Protocol
from urllib.parse import urlparse
@ -201,3 +202,12 @@ def remote_provider_spec(
adapter=adapter,
api_dependencies=api_dependencies or [],
)
class HealthStatus(str, Enum):
OK = "OK"
ERROR = "Error"
NOT_IMPLEMENTED = "Not Implemented"
HealthResponse = dict[str, Any]

View file

@ -178,6 +178,8 @@ class ChatAgent(ShieldRunnerMixin):
span.set_attribute("request", request.model_dump_json())
turn_id = str(uuid.uuid4())
span.set_attribute("turn_id", turn_id)
if self.agent_config.name:
span.set_attribute("agent_name", self.agent_config.name)
await self._initialize_tools(request.toolgroups)
async for chunk in self._run_turn(request, turn_id):
@ -190,6 +192,8 @@ class ChatAgent(ShieldRunnerMixin):
span.set_attribute("session_id", request.session_id)
span.set_attribute("request", request.model_dump_json())
span.set_attribute("turn_id", request.turn_id)
if self.agent_config.name:
span.set_attribute("agent_name", self.agent_config.name)
await self._initialize_tools()
async for chunk in self._run_turn(request):
@ -498,6 +502,8 @@ class ChatAgent(ShieldRunnerMixin):
stop_reason = None
async with tracing.span("inference") as span:
if self.agent_config.name:
span.set_attribute("agent_name", self.agent_config.name)
async for chunk in await self.inference_api.chat_completion(
self.agent_config.model,
input_messages,

View file

@ -52,14 +52,17 @@ class MetaReferenceInferenceConfig(BaseModel):
checkpoint_dir: str = "${env.CHECKPOINT_DIR:null}",
quantization_type: str = "${env.QUANTIZATION_TYPE:bf16}",
model_parallel_size: str = "${env.MODEL_PARALLEL_SIZE:0}",
max_batch_size: str = "${env.MAX_BATCH_SIZE:1}",
max_seq_len: str = "${env.MAX_SEQ_LEN:4096}",
**kwargs,
) -> Dict[str, Any]:
return {
"model": model,
"max_seq_len": 4096,
"checkpoint_dir": checkpoint_dir,
"quantization": {
"type": quantization_type,
},
"model_parallel_size": model_parallel_size,
"max_batch_size": max_batch_size,
"max_seq_len": max_seq_len,
}

View file

@ -22,7 +22,7 @@ from llama_stack.models.llama.llama3.generation import Llama3
from llama_stack.models.llama.llama3.tokenizer import Tokenizer as Llama3Tokenizer
from llama_stack.models.llama.llama4.generation import Llama4
from llama_stack.models.llama.llama4.tokenizer import Tokenizer as Llama4Tokenizer
from llama_stack.models.llama.sku_types import Model
from llama_stack.models.llama.sku_types import Model, ModelFamily
from llama_stack.providers.utils.inference.prompt_adapter import (
ChatCompletionRequestWithRawContent,
CompletionRequestWithRawContent,
@ -113,8 +113,7 @@ def _infer_tool_prompt_format(request: ChatCompletionRequestWithRawContent):
return get_default_tool_prompt_format(request.model)
# TODO: combine Llama3 and Llama4 generators since they are almost identical now
class Llama4Generator:
class LlamaGenerator:
def __init__(
self,
config: MetaReferenceInferenceConfig,
@ -144,7 +143,8 @@ class Llama4Generator:
else:
quantization_mode = None
self.inner_generator = Llama4.build(
cls = Llama4 if llama_model.model_family == ModelFamily.llama4 else Llama3
self.inner_generator = cls.build(
ckpt_dir=ckpt_dir,
max_seq_len=config.max_seq_len,
max_batch_size=config.max_batch_size,
@ -158,142 +158,55 @@ class Llama4Generator:
def completion(
self,
request: CompletionRequestWithRawContent,
request_batch: List[CompletionRequestWithRawContent],
) -> Generator:
sampling_params = request.sampling_params or SamplingParams()
first_request = request_batch[0]
sampling_params = first_request.sampling_params or SamplingParams()
max_gen_len = sampling_params.max_tokens
if max_gen_len is None or max_gen_len == 0 or max_gen_len >= self.args.max_seq_len:
max_gen_len = self.args.max_seq_len - 1
temperature, top_p = _infer_sampling_params(sampling_params)
for result in self.inner_generator.generate(
llm_inputs=[self.formatter.encode_content(request.content)],
llm_inputs=[self.formatter.encode_content(request.content) for request in request_batch],
max_gen_len=max_gen_len,
temperature=temperature,
top_p=top_p,
logprobs=bool(request.logprobs),
logprobs=bool(first_request.logprobs),
echo=False,
logits_processor=get_logits_processor(
self.tokenizer,
self.args.vocab_size,
request.response_format,
first_request.response_format,
),
):
yield result[0]
yield result
def chat_completion(
self,
request: ChatCompletionRequestWithRawContent,
request_batch: List[ChatCompletionRequestWithRawContent],
) -> Generator:
sampling_params = request.sampling_params or SamplingParams()
first_request = request_batch[0]
sampling_params = first_request.sampling_params or SamplingParams()
max_gen_len = sampling_params.max_tokens
if max_gen_len is None or max_gen_len == 0 or max_gen_len >= self.args.max_seq_len:
max_gen_len = self.args.max_seq_len - 1
temperature, top_p = _infer_sampling_params(sampling_params)
for result in self.inner_generator.generate(
llm_inputs=[self.formatter.encode_dialog_prompt(request.messages, _infer_tool_prompt_format(request))],
llm_inputs=[
self.formatter.encode_dialog_prompt(request.messages, _infer_tool_prompt_format(request))
for request in request_batch
],
max_gen_len=max_gen_len,
temperature=temperature,
top_p=top_p,
logprobs=bool(request.logprobs),
logprobs=bool(first_request.logprobs),
echo=False,
logits_processor=get_logits_processor(
self.tokenizer,
self.args.vocab_size,
request.response_format,
first_request.response_format,
),
):
yield result[0]
class Llama3Generator:
def __init__(
self,
config: MetaReferenceInferenceConfig,
model_id: str,
llama_model: Model,
):
if config.checkpoint_dir and config.checkpoint_dir != "null":
ckpt_dir = config.checkpoint_dir
else:
resolved_model = resolve_model(model_id)
if resolved_model is None:
# if the model is not a native llama model, get the default checkpoint_dir based on model id
ckpt_dir = model_checkpoint_dir(model_id)
else:
# if the model is a native llama model, get the default checkpoint_dir based on model core_model_id value
ckpt_dir = model_checkpoint_dir(resolved_model.descriptor())
if config.quantization:
if config.quantization.type == "fp8_mixed":
quantization_mode = QuantizationMode.fp8_mixed
elif config.quantization.type == "int4_mixed":
quantization_mode = QuantizationMode.int4_mixed
elif config.quantization.type == "bf16":
quantization_mode = None
else:
raise ValueError(f"Unsupported quantization mode {config.quantization}")
else:
quantization_mode = None
self.inner_generator = Llama3.build(
ckpt_dir=ckpt_dir,
max_seq_len=config.max_seq_len,
max_batch_size=config.max_batch_size,
world_size=config.model_parallel_size or llama_model.pth_file_count,
quantization_mode=quantization_mode,
)
self.tokenizer = self.inner_generator.tokenizer
self.args = self.inner_generator.args
self.formatter = self.inner_generator.formatter
def completion(
self,
request: CompletionRequestWithRawContent,
) -> Generator:
sampling_params = request.sampling_params or SamplingParams()
max_gen_len = sampling_params.max_tokens
if max_gen_len is None or max_gen_len == 0 or max_gen_len >= self.args.max_seq_len:
max_gen_len = self.args.max_seq_len - 1
temperature, top_p = _infer_sampling_params(sampling_params)
for result in self.inner_generator.generate(
model_inputs=[self.formatter.encode_content(request.content)],
max_gen_len=max_gen_len,
temperature=temperature,
top_p=top_p,
logprobs=bool(request.logprobs),
echo=False,
logits_processor=get_logits_processor(
self.tokenizer,
self.args.vocab_size,
request.response_format,
),
):
yield result[0]
def chat_completion(
self,
request: ChatCompletionRequestWithRawContent,
) -> Generator:
sampling_params = request.sampling_params or SamplingParams()
max_gen_len = sampling_params.max_tokens
if max_gen_len is None or max_gen_len == 0 or max_gen_len >= self.args.max_seq_len:
max_gen_len = self.args.max_seq_len - 1
temperature, top_p = _infer_sampling_params(sampling_params)
for result in self.inner_generator.generate(
model_inputs=[self.formatter.encode_dialog_prompt(request.messages, _infer_tool_prompt_format(request))],
max_gen_len=max_gen_len,
temperature=temperature,
top_p=top_p,
logprobs=bool(request.logprobs),
echo=False,
logits_processor=get_logits_processor(
self.tokenizer,
self.args.vocab_size,
request.response_format,
),
):
yield result[0]
yield result

View file

@ -5,10 +5,10 @@
# the root directory of this source tree.
import asyncio
import logging
import os
from typing import AsyncGenerator, List, Optional, Union
from pydantic import BaseModel
from termcolor import cprint
from llama_stack.apis.common.content_types import (
@ -17,6 +17,8 @@ from llama_stack.apis.common.content_types import (
ToolCallParseStatus,
)
from llama_stack.apis.inference import (
BatchChatCompletionResponse,
BatchCompletionResponse,
ChatCompletionRequest,
ChatCompletionResponse,
ChatCompletionResponseEvent,
@ -38,8 +40,10 @@ from llama_stack.apis.inference import (
ToolConfig,
ToolDefinition,
ToolPromptFormat,
UserMessage,
)
from llama_stack.apis.models import Model, ModelType
from llama_stack.log import get_logger
from llama_stack.models.llama.llama3.chat_format import ChatFormat as Llama3ChatFormat
from llama_stack.models.llama.llama3.tokenizer import Tokenizer as Llama3Tokenizer
from llama_stack.models.llama.llama4.chat_format import ChatFormat as Llama4ChatFormat
@ -54,6 +58,10 @@ from llama_stack.providers.utils.inference.model_registry import (
ModelRegistryHelper,
build_hf_repo_model_entry,
)
from llama_stack.providers.utils.inference.openai_compat import (
OpenAIChatCompletionToLlamaStackMixin,
OpenAICompletionToLlamaStackMixin,
)
from llama_stack.providers.utils.inference.prompt_adapter import (
augment_content_with_response_format_prompt,
chat_completion_request_to_messages,
@ -61,24 +69,22 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
)
from .config import MetaReferenceInferenceConfig
from .generators import Llama3Generator, Llama4Generator
from .generators import LlamaGenerator
from .model_parallel import LlamaModelParallelGenerator
log = logging.getLogger(__name__)
log = get_logger(__name__, category="inference")
# there's a single model parallel process running serving the model. for now,
# we don't support multiple concurrent requests to this process.
SEMAPHORE = asyncio.Semaphore(1)
def llama3_builder_fn(config: MetaReferenceInferenceConfig, model_id: str, llama_model: Model) -> Llama3Generator:
return Llama3Generator(config, model_id, llama_model)
def llama4_builder_fn(config: MetaReferenceInferenceConfig, model_id: str, llama_model: Model) -> Llama4Generator:
return Llama4Generator(config, model_id, llama_model)
def llama_builder_fn(config: MetaReferenceInferenceConfig, model_id: str, llama_model: Model) -> LlamaGenerator:
return LlamaGenerator(config, model_id, llama_model)
class MetaReferenceInferenceImpl(
OpenAICompletionToLlamaStackMixin,
OpenAIChatCompletionToLlamaStackMixin,
SentenceTransformerEmbeddingMixin,
Inference,
ModelsProtocolPrivate,
@ -133,24 +139,12 @@ class MetaReferenceInferenceImpl(
async def load_model(self, model_id, llama_model) -> None:
log.info(f"Loading model `{model_id}`")
if llama_model.model_family in {
ModelFamily.llama3,
ModelFamily.llama3_1,
ModelFamily.llama3_2,
ModelFamily.llama3_3,
}:
builder_fn = llama3_builder_fn
elif llama_model.model_family == ModelFamily.llama4:
builder_fn = llama4_builder_fn
else:
raise ValueError(f"Unsupported model family: {llama_model.model_family}")
builder_params = [self.config, model_id, llama_model]
if self.config.create_distributed_process_group:
self.generator = LlamaModelParallelGenerator(
model_parallel_size=self.config.model_parallel_size or llama_model.pth_file_count,
builder_fn=builder_fn,
builder_fn=llama_builder_fn,
builder_params=builder_params,
formatter=(
Llama4ChatFormat(Llama4Tokenizer.get_instance())
@ -160,11 +154,24 @@ class MetaReferenceInferenceImpl(
)
self.generator.start()
else:
self.generator = builder_fn(*builder_params)
self.generator = llama_builder_fn(*builder_params)
self.model_id = model_id
self.llama_model = llama_model
log.info("Warming up...")
await self.completion(
model_id=model_id,
content="Hello, world!",
sampling_params=SamplingParams(max_tokens=10),
)
await self.chat_completion(
model_id=model_id,
messages=[UserMessage(content="Hi how are you?")],
sampling_params=SamplingParams(max_tokens=20),
)
log.info("Warmed up!")
def check_model(self, request) -> None:
if self.model_id is None or self.llama_model is None:
raise RuntimeError(
@ -202,7 +209,43 @@ class MetaReferenceInferenceImpl(
if request.stream:
return self._stream_completion(request)
else:
return await self._nonstream_completion(request)
results = await self._nonstream_completion([request])
return results[0]
async def batch_completion(
self,
model_id: str,
content_batch: List[InterleavedContent],
sampling_params: Optional[SamplingParams] = None,
response_format: Optional[ResponseFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> BatchCompletionResponse:
if sampling_params is None:
sampling_params = SamplingParams()
if logprobs:
assert logprobs.top_k == 1, f"Unexpected top_k={logprobs.top_k}"
content_batch = [
augment_content_with_response_format_prompt(response_format, content) for content in content_batch
]
request_batch = []
for content in content_batch:
request = CompletionRequest(
model=model_id,
content=content,
sampling_params=sampling_params,
response_format=response_format,
stream=stream,
logprobs=logprobs,
)
self.check_model(request)
request = await convert_request_to_raw(request)
request_batch.append(request)
results = await self._nonstream_completion(request_batch)
return BatchCompletionResponse(batch=results)
async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator:
tokenizer = self.generator.formatter.tokenizer
@ -247,37 +290,54 @@ class MetaReferenceInferenceImpl(
for x in impl():
yield x
async def _nonstream_completion(self, request: CompletionRequest) -> CompletionResponse:
async def _nonstream_completion(self, request_batch: List[CompletionRequest]) -> List[CompletionResponse]:
tokenizer = self.generator.formatter.tokenizer
first_request = request_batch[0]
class ItemState(BaseModel):
tokens: List[int] = []
logprobs: List[TokenLogProbs] = []
stop_reason: StopReason | None = None
finished: bool = False
def impl():
tokens = []
logprobs = []
stop_reason = None
states = [ItemState() for _ in request_batch]
for token_result in self.generator.completion(request):
tokens.append(token_result.token)
if token_result.token == tokenizer.eot_id:
stop_reason = StopReason.end_of_turn
elif token_result.token == tokenizer.eom_id:
stop_reason = StopReason.end_of_message
results = []
for token_results in self.generator.completion(request_batch):
for result in token_results:
idx = result.batch_idx
state = states[idx]
if state.finished or result.ignore_token:
continue
if request.logprobs:
assert len(token_result.logprobs) == 1
state.finished = result.finished
if first_request.logprobs:
state.logprobs.append(TokenLogProbs(logprobs_by_token={result.text: result.logprobs[0]}))
logprobs.append(TokenLogProbs(logprobs_by_token={token_result.text: token_result.logprobs[0]}))
state.tokens.append(result.token)
if result.token == tokenizer.eot_id:
state.stop_reason = StopReason.end_of_turn
elif result.token == tokenizer.eom_id:
state.stop_reason = StopReason.end_of_message
if stop_reason is None:
stop_reason = StopReason.out_of_tokens
for state in states:
if state.stop_reason is None:
state.stop_reason = StopReason.out_of_tokens
if tokens[-1] in self.generator.formatter.tokenizer.stop_tokens:
tokens = tokens[:-1]
content = self.generator.formatter.tokenizer.decode(tokens)
return CompletionResponse(
content=content,
stop_reason=stop_reason,
logprobs=logprobs if request.logprobs else None,
)
if state.tokens[-1] in self.generator.formatter.tokenizer.stop_tokens:
state.tokens = state.tokens[:-1]
content = self.generator.formatter.tokenizer.decode(state.tokens)
results.append(
CompletionResponse(
content=content,
stop_reason=state.stop_reason,
logprobs=state.logprobs if first_request.logprobs else None,
)
)
return results
if self.config.create_distributed_process_group:
async with SEMAPHORE:
@ -312,7 +372,7 @@ class MetaReferenceInferenceImpl(
response_format=response_format,
stream=stream,
logprobs=logprobs,
tool_config=tool_config,
tool_config=tool_config or ToolConfig(),
)
self.check_model(request)
@ -328,44 +388,110 @@ class MetaReferenceInferenceImpl(
if request.stream:
return self._stream_chat_completion(request)
else:
return await self._nonstream_chat_completion(request)
results = await self._nonstream_chat_completion([request])
return results[0]
async def _nonstream_chat_completion(self, request: ChatCompletionRequest) -> ChatCompletionResponse:
async def batch_chat_completion(
self,
model_id: str,
messages_batch: List[List[Message]],
sampling_params: Optional[SamplingParams] = None,
response_format: Optional[ResponseFormat] = None,
tools: Optional[List[ToolDefinition]] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
tool_config: Optional[ToolConfig] = None,
) -> BatchChatCompletionResponse:
if sampling_params is None:
sampling_params = SamplingParams()
if logprobs:
assert logprobs.top_k == 1, f"Unexpected top_k={logprobs.top_k}"
# wrapper request to make it easier to pass around (internal only, not exposed to API)
request_batch = []
for messages in messages_batch:
request = ChatCompletionRequest(
model=model_id,
messages=messages,
sampling_params=sampling_params,
tools=tools or [],
response_format=response_format,
logprobs=logprobs,
tool_config=tool_config or ToolConfig(),
)
self.check_model(request)
# augment and rewrite messages depending on the model
request.messages = chat_completion_request_to_messages(request, self.llama_model.core_model_id.value)
# download media and convert to raw content so we can send it to the model
request = await convert_request_to_raw(request)
request_batch.append(request)
if self.config.create_distributed_process_group:
if SEMAPHORE.locked():
raise RuntimeError("Only one concurrent request is supported")
results = await self._nonstream_chat_completion(request_batch)
return BatchChatCompletionResponse(batch=results)
async def _nonstream_chat_completion(
self, request_batch: List[ChatCompletionRequest]
) -> List[ChatCompletionResponse]:
tokenizer = self.generator.formatter.tokenizer
first_request = request_batch[0]
class ItemState(BaseModel):
tokens: List[int] = []
logprobs: List[TokenLogProbs] = []
stop_reason: StopReason | None = None
finished: bool = False
def impl():
tokens = []
logprobs = []
stop_reason = None
states = [ItemState() for _ in request_batch]
for token_result in self.generator.chat_completion(request):
if os.environ.get("LLAMA_MODELS_DEBUG", "0") == "1":
cprint(token_result.text, "cyan", end="")
for token_results in self.generator.chat_completion(request_batch):
first = token_results[0]
if not first.finished and not first.ignore_token:
if os.environ.get("LLAMA_MODELS_DEBUG", "0") in ("1", "2"):
cprint(first.text, "cyan", end="")
if os.environ.get("LLAMA_MODELS_DEBUG", "0") == "2":
cprint(f"<{first.token}>", "magenta", end="")
tokens.append(token_result.token)
for result in token_results:
idx = result.batch_idx
state = states[idx]
if state.finished or result.ignore_token:
continue
if token_result.token == tokenizer.eot_id:
stop_reason = StopReason.end_of_turn
elif token_result.token == tokenizer.eom_id:
stop_reason = StopReason.end_of_message
state.finished = result.finished
if first_request.logprobs:
state.logprobs.append(TokenLogProbs(logprobs_by_token={result.text: result.logprobs[0]}))
if request.logprobs:
assert len(token_result.logprobs) == 1
state.tokens.append(result.token)
if result.token == tokenizer.eot_id:
state.stop_reason = StopReason.end_of_turn
elif result.token == tokenizer.eom_id:
state.stop_reason = StopReason.end_of_message
logprobs.append(TokenLogProbs(logprobs_by_token={token_result.text: token_result.logprobs[0]}))
results = []
for state in states:
if state.stop_reason is None:
state.stop_reason = StopReason.out_of_tokens
if stop_reason is None:
stop_reason = StopReason.out_of_tokens
raw_message = self.generator.formatter.decode_assistant_message(state.tokens, state.stop_reason)
results.append(
ChatCompletionResponse(
completion_message=CompletionMessage(
content=raw_message.content,
stop_reason=raw_message.stop_reason,
tool_calls=raw_message.tool_calls,
),
logprobs=state.logprobs if first_request.logprobs else None,
)
)
raw_message = self.generator.formatter.decode_assistant_message(tokens, stop_reason)
return ChatCompletionResponse(
completion_message=CompletionMessage(
content=raw_message.content,
stop_reason=raw_message.stop_reason,
tool_calls=raw_message.tool_calls,
),
logprobs=logprobs if request.logprobs else None,
)
return results
if self.config.create_distributed_process_group:
async with SEMAPHORE:
@ -389,9 +515,26 @@ class MetaReferenceInferenceImpl(
stop_reason = None
ipython = False
for token_result in self.generator.chat_completion(request):
for token_results in self.generator.chat_completion([request]):
token_result = token_results[0]
if os.environ.get("LLAMA_MODELS_DEBUG", "0") == "1":
cprint(token_result.text, "cyan", end="")
if os.environ.get("LLAMA_MODELS_DEBUG", "0") == "2":
cprint(f"<{token_result.token}>", "magenta", end="")
if token_result.token == tokenizer.eot_id:
stop_reason = StopReason.end_of_turn
text = ""
elif token_result.token == tokenizer.eom_id:
stop_reason = StopReason.end_of_message
text = ""
else:
text = token_result.text
if request.logprobs:
assert len(token_result.logprobs) == 1
logprobs.append(TokenLogProbs(logprobs_by_token={token_result.text: token_result.logprobs[0]}))
tokens.append(token_result.token)

View file

@ -6,7 +6,7 @@
from copy import deepcopy
from functools import partial
from typing import Any, Callable, Generator
from typing import Any, Callable, Generator, List
from llama_stack.models.llama.llama3.chat_format import ChatFormat as Llama3ChatFormat
from llama_stack.models.llama.llama4.chat_format import ChatFormat as Llama4ChatFormat
@ -23,13 +23,13 @@ class ModelRunner:
self.llama = llama
# the `task` object is the same that is sent to `ModelParallelProcessGroup.run_inference()`
def __call__(self, req: Any):
if isinstance(req, ChatCompletionRequestWithRawContent):
return self.llama.chat_completion(req)
elif isinstance(req, CompletionRequestWithRawContent):
return self.llama.completion(req)
def __call__(self, task: Any):
if task[0] == "chat_completion":
return self.llama.chat_completion(task[1])
elif task[0] == "completion":
return self.llama.completion(task[1])
else:
raise ValueError(f"Unexpected task type {type(req)}")
raise ValueError(f"Unexpected task type {task[0]}")
def init_model_cb(
@ -82,16 +82,16 @@ class LlamaModelParallelGenerator:
def completion(
self,
request: CompletionRequestWithRawContent,
request_batch: List[CompletionRequestWithRawContent],
) -> Generator:
req_obj = deepcopy(request)
gen = self.group.run_inference(req_obj)
req_obj = deepcopy(request_batch)
gen = self.group.run_inference(("completion", req_obj))
yield from gen
def chat_completion(
self,
request: ChatCompletionRequestWithRawContent,
request_batch: List[ChatCompletionRequestWithRawContent],
) -> Generator:
req_obj = deepcopy(request)
gen = self.group.run_inference(req_obj)
req_obj = deepcopy(request_batch)
gen = self.group.run_inference(("chat_completion", req_obj))
yield from gen

View file

@ -19,7 +19,7 @@ import tempfile
import time
import uuid
from enum import Enum
from typing import Callable, Generator, Literal, Optional, Union
from typing import Callable, Generator, List, Literal, Optional, Tuple, Union
import torch
import zmq
@ -69,12 +69,12 @@ class CancelSentinel(BaseModel):
class TaskRequest(BaseModel):
type: Literal[ProcessingMessageName.task_request] = ProcessingMessageName.task_request
task: Union[CompletionRequestWithRawContent, ChatCompletionRequestWithRawContent]
task: Tuple[str, List[CompletionRequestWithRawContent] | List[ChatCompletionRequestWithRawContent]]
class TaskResponse(BaseModel):
type: Literal[ProcessingMessageName.task_response] = ProcessingMessageName.task_response
result: GenerationResult
result: List[GenerationResult]
class ExceptionResponse(BaseModel):
@ -331,7 +331,7 @@ class ModelParallelProcessGroup:
def run_inference(
self,
req: Union[CompletionRequestWithRawContent, ChatCompletionRequestWithRawContent],
req: Tuple[str, List[CompletionRequestWithRawContent] | List[ChatCompletionRequestWithRawContent]],
) -> Generator:
assert not self.running, "inference already running"

View file

@ -10,6 +10,7 @@ from typing import AsyncGenerator, List, Optional, Union
from llama_stack.apis.inference import (
CompletionResponse,
Inference,
InterleavedContent,
LogProbConfig,
Message,
ResponseFormat,
@ -23,6 +24,10 @@ from llama_stack.providers.datatypes import Model, ModelsProtocolPrivate
from llama_stack.providers.utils.inference.embedding_mixin import (
SentenceTransformerEmbeddingMixin,
)
from llama_stack.providers.utils.inference.openai_compat import (
OpenAIChatCompletionToLlamaStackMixin,
OpenAICompletionToLlamaStackMixin,
)
from .config import SentenceTransformersInferenceConfig
@ -30,6 +35,8 @@ log = logging.getLogger(__name__)
class SentenceTransformersInferenceImpl(
OpenAIChatCompletionToLlamaStackMixin,
OpenAICompletionToLlamaStackMixin,
SentenceTransformerEmbeddingMixin,
Inference,
ModelsProtocolPrivate,
@ -74,3 +81,25 @@ class SentenceTransformersInferenceImpl(
tool_config: Optional[ToolConfig] = None,
) -> AsyncGenerator:
raise ValueError("Sentence transformers don't support chat completion")
async def batch_completion(
self,
model_id: str,
content_batch: List[InterleavedContent],
sampling_params: Optional[SamplingParams] = None,
response_format: Optional[ResponseFormat] = None,
logprobs: Optional[LogProbConfig] = None,
):
raise NotImplementedError("Batch completion is not supported for Sentence Transformers")
async def batch_chat_completion(
self,
model_id: str,
messages_batch: List[List[Message]],
sampling_params: Optional[SamplingParams] = None,
tools: Optional[List[ToolDefinition]] = None,
tool_config: Optional[ToolConfig] = None,
response_format: Optional[ResponseFormat] = None,
logprobs: Optional[LogProbConfig] = None,
):
raise NotImplementedError("Batch chat completion is not supported for Sentence Transformers")

View file

@ -66,8 +66,10 @@ from llama_stack.providers.utils.inference.model_registry import (
ModelsProtocolPrivate,
)
from llama_stack.providers.utils.inference.openai_compat import (
OpenAIChatCompletionToLlamaStackMixin,
OpenAICompatCompletionChoice,
OpenAICompatCompletionResponse,
OpenAICompletionToLlamaStackMixin,
get_stop_reason,
process_chat_completion_stream_response,
)
@ -172,7 +174,12 @@ def _convert_sampling_params(
return vllm_sampling_params
class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
class VLLMInferenceImpl(
Inference,
OpenAIChatCompletionToLlamaStackMixin,
OpenAICompletionToLlamaStackMixin,
ModelsProtocolPrivate,
):
"""
vLLM-based inference model adapter for Llama Stack with support for multiple models.

View file

@ -3,13 +3,14 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from datetime import datetime, timezone
from enum import Enum
from typing import Any, Dict, Optional
from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.datasets import Datasets
from llama_stack.apis.post_training import (
AlgorithmConfig,
Checkpoint,
DPOAlignmentConfig,
JobStatus,
ListPostTrainingJobsResponse,
@ -25,9 +26,19 @@ from llama_stack.providers.inline.post_training.torchtune.config import (
from llama_stack.providers.inline.post_training.torchtune.recipes.lora_finetuning_single_device import (
LoraFinetuningSingleDevice,
)
from llama_stack.providers.utils.scheduler import JobArtifact, Scheduler
from llama_stack.providers.utils.scheduler import JobStatus as SchedulerJobStatus
from llama_stack.schema_utils import webmethod
class TrainingArtifactType(Enum):
CHECKPOINT = "checkpoint"
RESOURCES_STATS = "resources_stats"
_JOB_TYPE_SUPERVISED_FINE_TUNE = "supervised-fine-tune"
class TorchtunePostTrainingImpl:
def __init__(
self,
@ -38,13 +49,27 @@ class TorchtunePostTrainingImpl:
self.config = config
self.datasetio_api = datasetio_api
self.datasets_api = datasets
self._scheduler = Scheduler()
# TODO: assume sync job, will need jobs API for async scheduling
self.jobs = {}
self.checkpoints_dict = {}
async def shutdown(self) -> None:
await self._scheduler.shutdown()
async def shutdown(self):
pass
@staticmethod
def _checkpoint_to_artifact(checkpoint: Checkpoint) -> JobArtifact:
return JobArtifact(
type=TrainingArtifactType.CHECKPOINT.value,
name=checkpoint.identifier,
uri=checkpoint.path,
metadata=dict(checkpoint),
)
@staticmethod
def _resources_stats_to_artifact(resources_stats: Dict[str, Any]) -> JobArtifact:
return JobArtifact(
type=TrainingArtifactType.RESOURCES_STATS.value,
name=TrainingArtifactType.RESOURCES_STATS.value,
metadata=resources_stats,
)
async def supervised_fine_tune(
self,
@ -56,20 +81,11 @@ class TorchtunePostTrainingImpl:
checkpoint_dir: Optional[str],
algorithm_config: Optional[AlgorithmConfig],
) -> PostTrainingJob:
if job_uuid in self.jobs:
raise ValueError(f"Job {job_uuid} already exists")
post_training_job = PostTrainingJob(job_uuid=job_uuid)
job_status_response = PostTrainingJobStatusResponse(
job_uuid=job_uuid,
status=JobStatus.scheduled,
scheduled_at=datetime.now(timezone.utc),
)
self.jobs[job_uuid] = job_status_response
if isinstance(algorithm_config, LoraFinetuningConfig):
try:
async def handler(on_log_message_cb, on_status_change_cb, on_artifact_collected_cb):
on_log_message_cb("Starting Lora finetuning")
recipe = LoraFinetuningSingleDevice(
self.config,
job_uuid,
@ -82,26 +98,22 @@ class TorchtunePostTrainingImpl:
self.datasetio_api,
self.datasets_api,
)
job_status_response.status = JobStatus.in_progress
job_status_response.started_at = datetime.now(timezone.utc)
await recipe.setup()
resources_allocated, checkpoints = await recipe.train()
self.checkpoints_dict[job_uuid] = checkpoints
job_status_response.resources_allocated = resources_allocated
job_status_response.checkpoints = checkpoints
job_status_response.status = JobStatus.completed
job_status_response.completed_at = datetime.now(timezone.utc)
on_artifact_collected_cb(self._resources_stats_to_artifact(resources_allocated))
for checkpoint in checkpoints:
artifact = self._checkpoint_to_artifact(checkpoint)
on_artifact_collected_cb(artifact)
except Exception:
job_status_response.status = JobStatus.failed
raise
on_status_change_cb(SchedulerJobStatus.completed)
on_log_message_cb("Lora finetuning completed")
else:
raise NotImplementedError()
return post_training_job
job_uuid = self._scheduler.schedule(_JOB_TYPE_SUPERVISED_FINE_TUNE, job_uuid, handler)
return PostTrainingJob(job_uuid=job_uuid)
async def preference_optimize(
self,
@ -114,19 +126,55 @@ class TorchtunePostTrainingImpl:
) -> PostTrainingJob: ...
async def get_training_jobs(self) -> ListPostTrainingJobsResponse:
return ListPostTrainingJobsResponse(data=[PostTrainingJob(job_uuid=uuid_) for uuid_ in self.jobs])
return ListPostTrainingJobsResponse(
data=[PostTrainingJob(job_uuid=job.id) for job in self._scheduler.get_jobs()]
)
@staticmethod
def _get_artifacts_metadata_by_type(job, artifact_type):
return [artifact.metadata for artifact in job.artifacts if artifact.type == artifact_type]
@classmethod
def _get_checkpoints(cls, job):
return cls._get_artifacts_metadata_by_type(job, TrainingArtifactType.CHECKPOINT.value)
@classmethod
def _get_resources_allocated(cls, job):
data = cls._get_artifacts_metadata_by_type(job, TrainingArtifactType.RESOURCES_STATS.value)
return data[0] if data else None
@webmethod(route="/post-training/job/status")
async def get_training_job_status(self, job_uuid: str) -> Optional[PostTrainingJobStatusResponse]:
return self.jobs.get(job_uuid, None)
job = self._scheduler.get_job(job_uuid)
match job.status:
# TODO: Add support for other statuses to API
case SchedulerJobStatus.new | SchedulerJobStatus.scheduled:
status = JobStatus.scheduled
case SchedulerJobStatus.running:
status = JobStatus.in_progress
case SchedulerJobStatus.completed:
status = JobStatus.completed
case SchedulerJobStatus.failed:
status = JobStatus.failed
case _:
raise NotImplementedError()
return PostTrainingJobStatusResponse(
job_uuid=job_uuid,
status=status,
scheduled_at=job.scheduled_at,
started_at=job.started_at,
completed_at=job.completed_at,
checkpoints=self._get_checkpoints(job),
resources_allocated=self._get_resources_allocated(job),
)
@webmethod(route="/post-training/job/cancel")
async def cancel_training_job(self, job_uuid: str) -> None:
raise NotImplementedError("Job cancel is not implemented yet")
self._scheduler.cancel(job_uuid)
@webmethod(route="/post-training/job/artifacts")
async def get_training_job_artifacts(self, job_uuid: str) -> Optional[PostTrainingJobArtifactsResponse]:
if job_uuid in self.checkpoints_dict:
checkpoints = self.checkpoints_dict.get(job_uuid, [])
return PostTrainingJobArtifactsResponse(job_uuid=job_uuid, checkpoints=checkpoints)
return None
job = self._scheduler.get_job(job_uuid)
return PostTrainingJobArtifactsResponse(job_uuid=job_uuid, checkpoints=self._get_checkpoints(job))

View file

@ -38,6 +38,8 @@ from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.datasets import Datasets
from llama_stack.apis.post_training import (
Checkpoint,
DataConfig,
EfficiencyConfig,
LoraFinetuningConfig,
OptimizerConfig,
QATFinetuningConfig,
@ -89,6 +91,10 @@ class LoraFinetuningSingleDevice:
datasetio_api: DatasetIO,
datasets_api: Datasets,
) -> None:
assert isinstance(training_config.data_config, DataConfig), "DataConfig must be initialized"
assert isinstance(training_config.efficiency_config, EfficiencyConfig), "EfficiencyConfig must be initialized"
self.job_uuid = job_uuid
self.training_config = training_config
if not isinstance(algorithm_config, LoraFinetuningConfig):
@ -188,6 +194,7 @@ class LoraFinetuningSingleDevice:
self._tokenizer = await self._setup_tokenizer()
log.info("Tokenizer is initialized.")
assert isinstance(self.training_config.optimizer_config, OptimizerConfig), "OptimizerConfig must be initialized"
self._optimizer = await self._setup_optimizer(optimizer_config=self.training_config.optimizer_config)
log.info("Optimizer is initialized.")
@ -195,6 +202,8 @@ class LoraFinetuningSingleDevice:
self._model.set_num_output_chunks(self._loss_fn.num_output_chunks)
log.info("Loss is initialized.")
assert isinstance(self.training_config.data_config, DataConfig), "DataConfig must be initialized"
self._training_sampler, self._training_dataloader = await self._setup_data(
dataset_id=self.training_config.data_config.dataset_id,
tokenizer=self._tokenizer,
@ -452,6 +461,7 @@ class LoraFinetuningSingleDevice:
"""
The core training loop.
"""
assert isinstance(self.training_config.data_config, DataConfig), "DataConfig must be initialized"
# Initialize tokens count and running loss (for grad accumulation)
t0 = time.perf_counter()
running_loss: float = 0.0

View file

@ -10,7 +10,6 @@ from typing import Any, Dict, List, Optional
from llama_stack.apis.common.content_types import ImageContentItem, TextContentItem
from llama_stack.apis.inference import (
ChatCompletionResponseEventType,
Inference,
Message,
UserMessage,
@ -239,16 +238,12 @@ class LlamaGuardShield:
shield_input_message = self.build_text_shield_input(messages)
# TODO: llama-stack inference protocol has issues with non-streaming inference code
content = ""
async for chunk in await self.inference_api.chat_completion(
response = await self.inference_api.chat_completion(
model_id=self.model,
messages=[shield_input_message],
stream=True,
):
event = chunk.event
if event.event_type == ChatCompletionResponseEventType.progress and event.delta.type == "text":
content += event.delta.text
stream=False,
)
content = response.completion_message.content
content = content.strip()
return self.get_shield_response(content)

View file

@ -24,7 +24,7 @@ META_REFERENCE_DEPS = [
"zmq",
"lm-format-enforcer",
"sentence-transformers",
"torchao==0.5.0",
"torchao==0.8.0",
"fbgemm-gpu-genai==1.1.2",
]

View file

@ -36,8 +36,10 @@ from llama_stack.providers.utils.inference.model_registry import (
ModelRegistryHelper,
)
from llama_stack.providers.utils.inference.openai_compat import (
OpenAIChatCompletionToLlamaStackMixin,
OpenAICompatCompletionChoice,
OpenAICompatCompletionResponse,
OpenAICompletionToLlamaStackMixin,
get_sampling_strategy_options,
process_chat_completion_response,
process_chat_completion_stream_response,
@ -51,7 +53,12 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
from .models import MODEL_ENTRIES
class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
class BedrockInferenceAdapter(
ModelRegistryHelper,
Inference,
OpenAIChatCompletionToLlamaStackMixin,
OpenAICompletionToLlamaStackMixin,
):
def __init__(self, config: BedrockConfig) -> None:
ModelRegistryHelper.__init__(self, MODEL_ENTRIES)
self._config = config

View file

@ -34,6 +34,8 @@ from llama_stack.providers.utils.inference.model_registry import (
ModelRegistryHelper,
)
from llama_stack.providers.utils.inference.openai_compat import (
OpenAIChatCompletionToLlamaStackMixin,
OpenAICompletionToLlamaStackMixin,
get_sampling_options,
process_chat_completion_response,
process_chat_completion_stream_response,
@ -49,7 +51,12 @@ from .config import CerebrasImplConfig
from .models import MODEL_ENTRIES
class CerebrasInferenceAdapter(ModelRegistryHelper, Inference):
class CerebrasInferenceAdapter(
ModelRegistryHelper,
Inference,
OpenAIChatCompletionToLlamaStackMixin,
OpenAICompletionToLlamaStackMixin,
):
def __init__(self, config: CerebrasImplConfig) -> None:
ModelRegistryHelper.__init__(
self,

View file

@ -34,6 +34,8 @@ from llama_stack.providers.utils.inference.model_registry import (
build_hf_repo_model_entry,
)
from llama_stack.providers.utils.inference.openai_compat import (
OpenAIChatCompletionToLlamaStackMixin,
OpenAICompletionToLlamaStackMixin,
get_sampling_options,
process_chat_completion_response,
process_chat_completion_stream_response,
@ -56,7 +58,12 @@ model_entries = [
]
class DatabricksInferenceAdapter(ModelRegistryHelper, Inference):
class DatabricksInferenceAdapter(
ModelRegistryHelper,
Inference,
OpenAIChatCompletionToLlamaStackMixin,
OpenAICompletionToLlamaStackMixin,
):
def __init__(self, config: DatabricksImplConfig) -> None:
ModelRegistryHelper.__init__(self, model_entries=model_entries)
self.config = config

View file

@ -4,9 +4,10 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import AsyncGenerator, List, Optional, Union
from typing import Any, AsyncGenerator, AsyncIterator, Dict, List, Optional, Union
from fireworks.client import Fireworks
from openai import AsyncOpenAI
from llama_stack.apis.common.content_types import (
InterleavedContent,
@ -31,14 +32,23 @@ from llama_stack.apis.inference import (
ToolDefinition,
ToolPromptFormat,
)
from llama_stack.apis.inference.inference import (
OpenAIChatCompletion,
OpenAIChatCompletionChunk,
OpenAICompletion,
OpenAIMessageParam,
OpenAIResponseFormatParam,
)
from llama_stack.distribution.request_headers import NeedsRequestProviderData
from llama_stack.log import get_logger
from llama_stack.providers.utils.inference.model_registry import (
ModelRegistryHelper,
)
from llama_stack.providers.utils.inference.openai_compat import (
OpenAIChatCompletionToLlamaStackMixin,
convert_message_to_openai_dict,
get_sampling_options,
prepare_openai_completion_params,
process_chat_completion_response,
process_chat_completion_stream_response,
process_completion_response,
@ -81,10 +91,16 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
)
return provider_data.fireworks_api_key
def _get_base_url(self) -> str:
return "https://api.fireworks.ai/inference/v1"
def _get_client(self) -> Fireworks:
fireworks_api_key = self._get_api_key()
return Fireworks(api_key=fireworks_api_key)
def _get_openai_client(self) -> AsyncOpenAI:
return AsyncOpenAI(base_url=self._get_base_url(), api_key=self._get_api_key())
async def completion(
self,
model_id: str,
@ -268,3 +284,140 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
embeddings = [data.embedding for data in response.data]
return EmbeddingsResponse(embeddings=embeddings)
async def openai_completion(
self,
model: str,
prompt: Union[str, List[str], List[int], List[List[int]]],
best_of: Optional[int] = None,
echo: Optional[bool] = None,
frequency_penalty: Optional[float] = None,
logit_bias: Optional[Dict[str, float]] = None,
logprobs: Optional[bool] = None,
max_tokens: Optional[int] = None,
n: Optional[int] = None,
presence_penalty: Optional[float] = None,
seed: Optional[int] = None,
stop: Optional[Union[str, List[str]]] = None,
stream: Optional[bool] = None,
stream_options: Optional[Dict[str, Any]] = None,
temperature: Optional[float] = None,
top_p: Optional[float] = None,
user: Optional[str] = None,
guided_choice: Optional[List[str]] = None,
prompt_logprobs: Optional[int] = None,
) -> OpenAICompletion:
model_obj = await self.model_store.get_model(model)
# Fireworks always prepends with BOS
if isinstance(prompt, str) and prompt.startswith("<|begin_of_text|>"):
prompt = prompt[len("<|begin_of_text|>") :]
params = await prepare_openai_completion_params(
model=model_obj.provider_resource_id,
prompt=prompt,
best_of=best_of,
echo=echo,
frequency_penalty=frequency_penalty,
logit_bias=logit_bias,
logprobs=logprobs,
max_tokens=max_tokens,
n=n,
presence_penalty=presence_penalty,
seed=seed,
stop=stop,
stream=stream,
stream_options=stream_options,
temperature=temperature,
top_p=top_p,
user=user,
)
return await self._get_openai_client().completions.create(**params)
async def openai_chat_completion(
self,
model: str,
messages: List[OpenAIMessageParam],
frequency_penalty: Optional[float] = None,
function_call: Optional[Union[str, Dict[str, Any]]] = None,
functions: Optional[List[Dict[str, Any]]] = None,
logit_bias: Optional[Dict[str, float]] = None,
logprobs: Optional[bool] = None,
max_completion_tokens: Optional[int] = None,
max_tokens: Optional[int] = None,
n: Optional[int] = None,
parallel_tool_calls: Optional[bool] = None,
presence_penalty: Optional[float] = None,
response_format: Optional[OpenAIResponseFormatParam] = None,
seed: Optional[int] = None,
stop: Optional[Union[str, List[str]]] = None,
stream: Optional[bool] = None,
stream_options: Optional[Dict[str, Any]] = None,
temperature: Optional[float] = None,
tool_choice: Optional[Union[str, Dict[str, Any]]] = None,
tools: Optional[List[Dict[str, Any]]] = None,
top_logprobs: Optional[int] = None,
top_p: Optional[float] = None,
user: Optional[str] = None,
) -> Union[OpenAIChatCompletion, AsyncIterator[OpenAIChatCompletionChunk]]:
model_obj = await self.model_store.get_model(model)
# Divert Llama Models through Llama Stack inference APIs because
# Fireworks chat completions OpenAI-compatible API does not support
# tool calls properly.
llama_model = self.get_llama_model(model_obj.provider_resource_id)
if llama_model:
return await OpenAIChatCompletionToLlamaStackMixin.openai_chat_completion(
self,
model=model,
messages=messages,
frequency_penalty=frequency_penalty,
function_call=function_call,
functions=functions,
logit_bias=logit_bias,
logprobs=logprobs,
max_completion_tokens=max_completion_tokens,
max_tokens=max_tokens,
n=n,
parallel_tool_calls=parallel_tool_calls,
presence_penalty=presence_penalty,
response_format=response_format,
seed=seed,
stop=stop,
stream=stream,
stream_options=stream_options,
temperature=temperature,
tool_choice=tool_choice,
tools=tools,
top_logprobs=top_logprobs,
top_p=top_p,
user=user,
)
params = await prepare_openai_completion_params(
messages=messages,
frequency_penalty=frequency_penalty,
function_call=function_call,
functions=functions,
logit_bias=logit_bias,
logprobs=logprobs,
max_completion_tokens=max_completion_tokens,
max_tokens=max_tokens,
n=n,
parallel_tool_calls=parallel_tool_calls,
presence_penalty=presence_penalty,
response_format=response_format,
seed=seed,
stop=stop,
stream=stream,
stream_options=stream_options,
temperature=temperature,
tool_choice=tool_choice,
tools=tools,
top_logprobs=top_logprobs,
top_p=top_p,
user=user,
)
return await self._get_openai_client().chat.completions.create(model=model_obj.provider_resource_id, **params)

View file

@ -4,8 +4,24 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, AsyncIterator, Dict, List, Optional, Union
from openai import AsyncOpenAI
from llama_stack.apis.inference.inference import (
OpenAIChatCompletion,
OpenAIChatCompletionChunk,
OpenAIChoiceDelta,
OpenAIChunkChoice,
OpenAIMessageParam,
OpenAIResponseFormatParam,
OpenAISystemMessageParam,
)
from llama_stack.providers.remote.inference.groq.config import GroqConfig
from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
from llama_stack.providers.utils.inference.openai_compat import (
prepare_openai_completion_params,
)
from .models import MODEL_ENTRIES
@ -21,9 +37,129 @@ class GroqInferenceAdapter(LiteLLMOpenAIMixin):
provider_data_api_key_field="groq_api_key",
)
self.config = config
self._openai_client = None
async def initialize(self):
await super().initialize()
async def shutdown(self):
await super().shutdown()
if self._openai_client:
await self._openai_client.close()
self._openai_client = None
def _get_openai_client(self) -> AsyncOpenAI:
if not self._openai_client:
self._openai_client = AsyncOpenAI(
base_url=f"{self.config.url}/openai/v1",
api_key=self.config.api_key,
)
return self._openai_client
async def openai_chat_completion(
self,
model: str,
messages: List[OpenAIMessageParam],
frequency_penalty: Optional[float] = None,
function_call: Optional[Union[str, Dict[str, Any]]] = None,
functions: Optional[List[Dict[str, Any]]] = None,
logit_bias: Optional[Dict[str, float]] = None,
logprobs: Optional[bool] = None,
max_completion_tokens: Optional[int] = None,
max_tokens: Optional[int] = None,
n: Optional[int] = None,
parallel_tool_calls: Optional[bool] = None,
presence_penalty: Optional[float] = None,
response_format: Optional[OpenAIResponseFormatParam] = None,
seed: Optional[int] = None,
stop: Optional[Union[str, List[str]]] = None,
stream: Optional[bool] = None,
stream_options: Optional[Dict[str, Any]] = None,
temperature: Optional[float] = None,
tool_choice: Optional[Union[str, Dict[str, Any]]] = None,
tools: Optional[List[Dict[str, Any]]] = None,
top_logprobs: Optional[int] = None,
top_p: Optional[float] = None,
user: Optional[str] = None,
) -> Union[OpenAIChatCompletion, AsyncIterator[OpenAIChatCompletionChunk]]:
model_obj = await self.model_store.get_model(model)
# Groq does not support json_schema response format, so we need to convert it to json_object
if response_format and response_format.type == "json_schema":
response_format.type = "json_object"
schema = response_format.json_schema.get("schema", {})
response_format.json_schema = None
json_instructions = f"\nYour response should be a JSON object that matches the following schema: {schema}"
if messages and messages[0].role == "system":
messages[0].content = messages[0].content + json_instructions
else:
messages.insert(0, OpenAISystemMessageParam(content=json_instructions))
# Groq returns a 400 error if tools are provided but none are called
# So, set tool_choice to "required" to attempt to force a call
if tools and (not tool_choice or tool_choice == "auto"):
tool_choice = "required"
params = await prepare_openai_completion_params(
model=model_obj.provider_resource_id.replace("groq/", ""),
messages=messages,
frequency_penalty=frequency_penalty,
function_call=function_call,
functions=functions,
logit_bias=logit_bias,
logprobs=logprobs,
max_completion_tokens=max_completion_tokens,
max_tokens=max_tokens,
n=n,
parallel_tool_calls=parallel_tool_calls,
presence_penalty=presence_penalty,
response_format=response_format,
seed=seed,
stop=stop,
stream=stream,
stream_options=stream_options,
temperature=temperature,
tool_choice=tool_choice,
tools=tools,
top_logprobs=top_logprobs,
top_p=top_p,
user=user,
)
# Groq does not support streaming requests that set response_format
fake_stream = False
if stream and response_format:
params["stream"] = False
fake_stream = True
response = await self._get_openai_client().chat.completions.create(**params)
if fake_stream:
chunk_choices = []
for choice in response.choices:
delta = OpenAIChoiceDelta(
content=choice.message.content,
role=choice.message.role,
tool_calls=choice.message.tool_calls,
)
chunk_choice = OpenAIChunkChoice(
delta=delta,
finish_reason=choice.finish_reason,
index=choice.index,
logprobs=None,
)
chunk_choices.append(chunk_choice)
chunk = OpenAIChatCompletionChunk(
id=response.id,
choices=chunk_choices,
object="chat.completion.chunk",
created=response.created,
model=response.model,
)
async def _fake_stream_generator():
yield chunk
return _fake_stream_generator()
else:
return response

View file

@ -39,8 +39,16 @@ MODEL_ENTRIES = [
"groq/llama-4-scout-17b-16e-instruct",
CoreModelId.llama4_scout_17b_16e_instruct.value,
),
build_hf_repo_model_entry(
"groq/meta-llama/llama-4-scout-17b-16e-instruct",
CoreModelId.llama4_scout_17b_16e_instruct.value,
),
build_hf_repo_model_entry(
"groq/llama-4-maverick-17b-128e-instruct",
CoreModelId.llama4_maverick_17b_128e_instruct.value,
),
build_hf_repo_model_entry(
"groq/meta-llama/llama-4-maverick-17b-128e-instruct",
CoreModelId.llama4_maverick_17b_128e_instruct.value,
),
]

View file

@ -0,0 +1,85 @@
# NVIDIA Inference Provider for LlamaStack
This provider enables running inference using NVIDIA NIM.
## Features
- Endpoints for completions, chat completions, and embeddings for registered models
## Getting Started
### Prerequisites
- LlamaStack with NVIDIA configuration
- Access to NVIDIA NIM deployment
- NIM for model to use for inference is deployed
### Setup
Build the NVIDIA environment:
```bash
llama stack build --template nvidia --image-type conda
```
### Basic Usage using the LlamaStack Python Client
#### Initialize the client
```python
import os
os.environ["NVIDIA_API_KEY"] = (
"" # Required if using hosted NIM endpoint. If self-hosted, not required.
)
os.environ["NVIDIA_BASE_URL"] = "http://nim.test" # NIM URL
from llama_stack.distribution.library_client import LlamaStackAsLibraryClient
client = LlamaStackAsLibraryClient("nvidia")
client.initialize()
```
### Create Completion
```python
response = client.completion(
model_id="meta-llama/Llama-3.1-8b-Instruct",
content="Complete the sentence using one word: Roses are red, violets are :",
stream=False,
sampling_params={
"max_tokens": 50,
},
)
print(f"Response: {response.content}")
```
### Create Chat Completion
```python
response = client.chat_completion(
model_id="meta-llama/Llama-3.1-8b-Instruct",
messages=[
{
"role": "system",
"content": "You must respond to each message with only one word",
},
{
"role": "user",
"content": "Complete the sentence using one word: Roses are red, violets are:",
},
],
stream=False,
sampling_params={
"max_tokens": 50,
},
)
print(f"Response: {response.completion_message.content}")
```
### Create Embeddings
```python
response = client.embeddings(
model_id="meta-llama/Llama-3.1-8b-Instruct", contents=["foo", "bar", "baz"]
)
print(f"Embeddings: {response.embeddings}")
```

View file

@ -48,6 +48,10 @@ MODEL_ENTRIES = [
"meta/llama-3.2-90b-vision-instruct",
CoreModelId.llama3_2_90b_vision_instruct.value,
),
build_hf_repo_model_entry(
"meta/llama-3.3-70b-instruct",
CoreModelId.llama3_3_70b_instruct.value,
),
# NeMo Retriever Text Embedding models -
#
# https://docs.nvidia.com/nim/nemo-retriever/text-embedding/latest/support-matrix.html

View file

@ -7,7 +7,7 @@
import logging
import warnings
from functools import lru_cache
from typing import AsyncIterator, List, Optional, Union
from typing import Any, AsyncIterator, Dict, List, Optional, Union
from openai import APIConnectionError, AsyncOpenAI, BadRequestError
@ -35,6 +35,13 @@ from llama_stack.apis.inference import (
ToolConfig,
ToolDefinition,
)
from llama_stack.apis.inference.inference import (
OpenAIChatCompletion,
OpenAIChatCompletionChunk,
OpenAICompletion,
OpenAIMessageParam,
OpenAIResponseFormatParam,
)
from llama_stack.models.llama.datatypes import ToolPromptFormat
from llama_stack.providers.utils.inference.model_registry import (
ModelRegistryHelper,
@ -42,6 +49,7 @@ from llama_stack.providers.utils.inference.model_registry import (
from llama_stack.providers.utils.inference.openai_compat import (
convert_openai_chat_completion_choice,
convert_openai_chat_completion_stream,
prepare_openai_completion_params,
)
from llama_stack.providers.utils.inference.prompt_adapter import content_has_media
@ -118,6 +126,14 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
return _get_client_for_base_url(base_url)
async def _get_provider_model_id(self, model_id: str) -> str:
if not self.model_store:
raise RuntimeError("Model store is not set")
model = await self.model_store.get_model(model_id)
if model is None:
raise ValueError(f"Model {model_id} is unknown")
return model.provider_model_id
async def completion(
self,
model_id: str,
@ -136,7 +152,7 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
# removing this health check as NeMo customizer endpoint health check is returning 404
# await check_health(self._config) # this raises errors
provider_model_id = self.get_provider_model_id(model_id)
provider_model_id = await self._get_provider_model_id(model_id)
request = convert_completion_request(
request=CompletionRequest(
model=provider_model_id,
@ -180,7 +196,7 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
#
flat_contents = [content.text if isinstance(content, TextContentItem) else content for content in contents]
input = [content.text if isinstance(content, TextContentItem) else content for content in flat_contents]
model = self.get_provider_model_id(model_id)
provider_model_id = await self._get_provider_model_id(model_id)
extra_body = {}
@ -203,8 +219,8 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
extra_body["input_type"] = task_type_options[task_type]
try:
response = await self._get_client(model).embeddings.create(
model=model,
response = await self._get_client(provider_model_id).embeddings.create(
model=provider_model_id,
input=input,
extra_body=extra_body,
)
@ -238,10 +254,10 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
# await check_health(self._config) # this raises errors
provider_model_id = self.get_provider_model_id(model_id)
provider_model_id = await self._get_provider_model_id(model_id)
request = await convert_chat_completion_request(
request=ChatCompletionRequest(
model=self.get_provider_model_id(model_id),
model=provider_model_id,
messages=messages,
sampling_params=sampling_params,
response_format=response_format,
@ -263,3 +279,111 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
else:
# we pass n=1 to get only one completion
return convert_openai_chat_completion_choice(response.choices[0])
async def openai_completion(
self,
model: str,
prompt: Union[str, List[str], List[int], List[List[int]]],
best_of: Optional[int] = None,
echo: Optional[bool] = None,
frequency_penalty: Optional[float] = None,
logit_bias: Optional[Dict[str, float]] = None,
logprobs: Optional[bool] = None,
max_tokens: Optional[int] = None,
n: Optional[int] = None,
presence_penalty: Optional[float] = None,
seed: Optional[int] = None,
stop: Optional[Union[str, List[str]]] = None,
stream: Optional[bool] = None,
stream_options: Optional[Dict[str, Any]] = None,
temperature: Optional[float] = None,
top_p: Optional[float] = None,
user: Optional[str] = None,
guided_choice: Optional[List[str]] = None,
prompt_logprobs: Optional[int] = None,
) -> OpenAICompletion:
provider_model_id = await self._get_provider_model_id(model)
params = await prepare_openai_completion_params(
model=provider_model_id,
prompt=prompt,
best_of=best_of,
echo=echo,
frequency_penalty=frequency_penalty,
logit_bias=logit_bias,
logprobs=logprobs,
max_tokens=max_tokens,
n=n,
presence_penalty=presence_penalty,
seed=seed,
stop=stop,
stream=stream,
stream_options=stream_options,
temperature=temperature,
top_p=top_p,
user=user,
)
try:
return await self._get_client(provider_model_id).completions.create(**params)
except APIConnectionError as e:
raise ConnectionError(f"Failed to connect to NVIDIA NIM at {self._config.url}: {e}") from e
async def openai_chat_completion(
self,
model: str,
messages: List[OpenAIMessageParam],
frequency_penalty: Optional[float] = None,
function_call: Optional[Union[str, Dict[str, Any]]] = None,
functions: Optional[List[Dict[str, Any]]] = None,
logit_bias: Optional[Dict[str, float]] = None,
logprobs: Optional[bool] = None,
max_completion_tokens: Optional[int] = None,
max_tokens: Optional[int] = None,
n: Optional[int] = None,
parallel_tool_calls: Optional[bool] = None,
presence_penalty: Optional[float] = None,
response_format: Optional[OpenAIResponseFormatParam] = None,
seed: Optional[int] = None,
stop: Optional[Union[str, List[str]]] = None,
stream: Optional[bool] = None,
stream_options: Optional[Dict[str, Any]] = None,
temperature: Optional[float] = None,
tool_choice: Optional[Union[str, Dict[str, Any]]] = None,
tools: Optional[List[Dict[str, Any]]] = None,
top_logprobs: Optional[int] = None,
top_p: Optional[float] = None,
user: Optional[str] = None,
) -> Union[OpenAIChatCompletion, AsyncIterator[OpenAIChatCompletionChunk]]:
provider_model_id = await self._get_provider_model_id(model)
params = await prepare_openai_completion_params(
model=provider_model_id,
messages=messages,
frequency_penalty=frequency_penalty,
function_call=function_call,
functions=functions,
logit_bias=logit_bias,
logprobs=logprobs,
max_completion_tokens=max_completion_tokens,
max_tokens=max_tokens,
n=n,
parallel_tool_calls=parallel_tool_calls,
presence_penalty=presence_penalty,
response_format=response_format,
seed=seed,
stop=stop,
stream=stream,
stream_options=stream_options,
temperature=temperature,
tool_choice=tool_choice,
tools=tools,
top_logprobs=top_logprobs,
top_p=top_p,
user=user,
)
try:
return await self._get_client(provider_model_id).chat.completions.create(**params)
except APIConnectionError as e:
raise ConnectionError(f"Failed to connect to NVIDIA NIM at {self._config.url}: {e}") from e

View file

@ -5,10 +5,11 @@
# the root directory of this source tree.
from typing import Any, AsyncGenerator, List, Optional, Union
from typing import Any, AsyncGenerator, AsyncIterator, Dict, List, Optional, Union
import httpx
from ollama import AsyncClient
from openai import AsyncOpenAI
from llama_stack.apis.common.content_types import (
ImageContentItem,
@ -38,9 +39,20 @@ from llama_stack.apis.inference import (
ToolDefinition,
ToolPromptFormat,
)
from llama_stack.apis.inference.inference import (
OpenAIChatCompletion,
OpenAIChatCompletionChunk,
OpenAICompletion,
OpenAIMessageParam,
OpenAIResponseFormatParam,
)
from llama_stack.apis.models import Model, ModelType
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import ModelsProtocolPrivate
from llama_stack.providers.datatypes import (
HealthResponse,
HealthStatus,
ModelsProtocolPrivate,
)
from llama_stack.providers.utils.inference.model_registry import (
ModelRegistryHelper,
)
@ -67,7 +79,10 @@ from .models import model_entries
logger = get_logger(name=__name__, category="inference")
class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
class OllamaInferenceAdapter(
Inference,
ModelsProtocolPrivate,
):
def __init__(self, url: str) -> None:
self.register_helper = ModelRegistryHelper(model_entries)
self.url = url
@ -76,10 +91,25 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
def client(self) -> AsyncClient:
return AsyncClient(host=self.url)
@property
def openai_client(self) -> AsyncOpenAI:
return AsyncOpenAI(base_url=f"{self.url}/v1", api_key="ollama")
async def initialize(self) -> None:
logger.info(f"checking connectivity to Ollama at `{self.url}`...")
await self.health()
async def health(self) -> HealthResponse:
"""
Performs a health check by verifying connectivity to the Ollama server.
This method is used by initialize() and the Provider API to verify that the service is running
correctly.
Returns:
HealthResponse: A dictionary containing the health status.
"""
try:
await self.client.ps()
return HealthResponse(status=HealthStatus.OK)
except httpx.ConnectError as e:
raise RuntimeError(
"Ollama Server is not running, start it using `ollama serve` in a separate terminal"
@ -313,12 +343,149 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
response = await self.client.list()
available_models = [m["model"] for m in response["models"]]
if model.provider_resource_id not in available_models:
available_models_latest = [m["model"].split(":latest")[0] for m in response["models"]]
if model.provider_resource_id in available_models_latest:
logger.warning(
f"Imprecise provider resource id was used but 'latest' is available in Ollama - using '{model.provider_resource_id}:latest'"
)
return model
raise ValueError(
f"Model '{model.provider_resource_id}' is not available in Ollama. Available models: {', '.join(available_models)}"
)
return model
async def openai_completion(
self,
model: str,
prompt: Union[str, List[str], List[int], List[List[int]]],
best_of: Optional[int] = None,
echo: Optional[bool] = None,
frequency_penalty: Optional[float] = None,
logit_bias: Optional[Dict[str, float]] = None,
logprobs: Optional[bool] = None,
max_tokens: Optional[int] = None,
n: Optional[int] = None,
presence_penalty: Optional[float] = None,
seed: Optional[int] = None,
stop: Optional[Union[str, List[str]]] = None,
stream: Optional[bool] = None,
stream_options: Optional[Dict[str, Any]] = None,
temperature: Optional[float] = None,
top_p: Optional[float] = None,
user: Optional[str] = None,
guided_choice: Optional[List[str]] = None,
prompt_logprobs: Optional[int] = None,
) -> OpenAICompletion:
if not isinstance(prompt, str):
raise ValueError("Ollama does not support non-string prompts for completion")
model_obj = await self._get_model(model)
params = {
k: v
for k, v in {
"model": model_obj.provider_resource_id,
"prompt": prompt,
"best_of": best_of,
"echo": echo,
"frequency_penalty": frequency_penalty,
"logit_bias": logit_bias,
"logprobs": logprobs,
"max_tokens": max_tokens,
"n": n,
"presence_penalty": presence_penalty,
"seed": seed,
"stop": stop,
"stream": stream,
"stream_options": stream_options,
"temperature": temperature,
"top_p": top_p,
"user": user,
}.items()
if v is not None
}
return await self.openai_client.completions.create(**params) # type: ignore
async def openai_chat_completion(
self,
model: str,
messages: List[OpenAIMessageParam],
frequency_penalty: Optional[float] = None,
function_call: Optional[Union[str, Dict[str, Any]]] = None,
functions: Optional[List[Dict[str, Any]]] = None,
logit_bias: Optional[Dict[str, float]] = None,
logprobs: Optional[bool] = None,
max_completion_tokens: Optional[int] = None,
max_tokens: Optional[int] = None,
n: Optional[int] = None,
parallel_tool_calls: Optional[bool] = None,
presence_penalty: Optional[float] = None,
response_format: Optional[OpenAIResponseFormatParam] = None,
seed: Optional[int] = None,
stop: Optional[Union[str, List[str]]] = None,
stream: Optional[bool] = None,
stream_options: Optional[Dict[str, Any]] = None,
temperature: Optional[float] = None,
tool_choice: Optional[Union[str, Dict[str, Any]]] = None,
tools: Optional[List[Dict[str, Any]]] = None,
top_logprobs: Optional[int] = None,
top_p: Optional[float] = None,
user: Optional[str] = None,
) -> Union[OpenAIChatCompletion, AsyncIterator[OpenAIChatCompletionChunk]]:
model_obj = await self._get_model(model)
params = {
k: v
for k, v in {
"model": model_obj.provider_resource_id,
"messages": messages,
"frequency_penalty": frequency_penalty,
"function_call": function_call,
"functions": functions,
"logit_bias": logit_bias,
"logprobs": logprobs,
"max_completion_tokens": max_completion_tokens,
"max_tokens": max_tokens,
"n": n,
"parallel_tool_calls": parallel_tool_calls,
"presence_penalty": presence_penalty,
"response_format": response_format,
"seed": seed,
"stop": stop,
"stream": stream,
"stream_options": stream_options,
"temperature": temperature,
"tool_choice": tool_choice,
"tools": tools,
"top_logprobs": top_logprobs,
"top_p": top_p,
"user": user,
}.items()
if v is not None
}
return await self.openai_client.chat.completions.create(**params) # type: ignore
async def batch_completion(
self,
model_id: str,
content_batch: List[InterleavedContent],
sampling_params: Optional[SamplingParams] = None,
response_format: Optional[ResponseFormat] = None,
logprobs: Optional[LogProbConfig] = None,
):
raise NotImplementedError("Batch completion is not supported for Ollama")
async def batch_chat_completion(
self,
model_id: str,
messages_batch: List[List[Message]],
sampling_params: Optional[SamplingParams] = None,
tools: Optional[List[ToolDefinition]] = None,
tool_config: Optional[ToolConfig] = None,
response_format: Optional[ResponseFormat] = None,
logprobs: Optional[LogProbConfig] = None,
):
raise NotImplementedError("Batch chat completion is not supported for Ollama")
async def convert_message_to_openai_dict_for_ollama(message: Message) -> List[dict]:
async def _convert_content(content) -> dict:

View file

@ -4,7 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, AsyncGenerator, Dict, List, Optional
from typing import Any, AsyncGenerator, AsyncIterator, Dict, List, Optional, Union
from llama_stack_client import AsyncLlamaStackClient
@ -26,9 +26,17 @@ from llama_stack.apis.inference import (
ToolDefinition,
ToolPromptFormat,
)
from llama_stack.apis.inference.inference import (
OpenAIChatCompletion,
OpenAIChatCompletionChunk,
OpenAICompletion,
OpenAIMessageParam,
OpenAIResponseFormatParam,
)
from llama_stack.apis.models import Model
from llama_stack.distribution.library_client import convert_pydantic_to_json_value, convert_to_pydantic
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
from llama_stack.providers.utils.inference.openai_compat import prepare_openai_completion_params
from .config import PassthroughImplConfig
@ -201,6 +209,112 @@ class PassthroughInferenceAdapter(Inference):
task_type=task_type,
)
async def openai_completion(
self,
model: str,
prompt: Union[str, List[str], List[int], List[List[int]]],
best_of: Optional[int] = None,
echo: Optional[bool] = None,
frequency_penalty: Optional[float] = None,
logit_bias: Optional[Dict[str, float]] = None,
logprobs: Optional[bool] = None,
max_tokens: Optional[int] = None,
n: Optional[int] = None,
presence_penalty: Optional[float] = None,
seed: Optional[int] = None,
stop: Optional[Union[str, List[str]]] = None,
stream: Optional[bool] = None,
stream_options: Optional[Dict[str, Any]] = None,
temperature: Optional[float] = None,
top_p: Optional[float] = None,
user: Optional[str] = None,
guided_choice: Optional[List[str]] = None,
prompt_logprobs: Optional[int] = None,
) -> OpenAICompletion:
client = self._get_client()
model_obj = await self.model_store.get_model(model)
params = await prepare_openai_completion_params(
model=model_obj.provider_resource_id,
prompt=prompt,
best_of=best_of,
echo=echo,
frequency_penalty=frequency_penalty,
logit_bias=logit_bias,
logprobs=logprobs,
max_tokens=max_tokens,
n=n,
presence_penalty=presence_penalty,
seed=seed,
stop=stop,
stream=stream,
stream_options=stream_options,
temperature=temperature,
top_p=top_p,
user=user,
guided_choice=guided_choice,
prompt_logprobs=prompt_logprobs,
)
return await client.inference.openai_completion(**params)
async def openai_chat_completion(
self,
model: str,
messages: List[OpenAIMessageParam],
frequency_penalty: Optional[float] = None,
function_call: Optional[Union[str, Dict[str, Any]]] = None,
functions: Optional[List[Dict[str, Any]]] = None,
logit_bias: Optional[Dict[str, float]] = None,
logprobs: Optional[bool] = None,
max_completion_tokens: Optional[int] = None,
max_tokens: Optional[int] = None,
n: Optional[int] = None,
parallel_tool_calls: Optional[bool] = None,
presence_penalty: Optional[float] = None,
response_format: Optional[OpenAIResponseFormatParam] = None,
seed: Optional[int] = None,
stop: Optional[Union[str, List[str]]] = None,
stream: Optional[bool] = None,
stream_options: Optional[Dict[str, Any]] = None,
temperature: Optional[float] = None,
tool_choice: Optional[Union[str, Dict[str, Any]]] = None,
tools: Optional[List[Dict[str, Any]]] = None,
top_logprobs: Optional[int] = None,
top_p: Optional[float] = None,
user: Optional[str] = None,
) -> Union[OpenAIChatCompletion, AsyncIterator[OpenAIChatCompletionChunk]]:
client = self._get_client()
model_obj = await self.model_store.get_model(model)
params = await prepare_openai_completion_params(
model=model_obj.provider_resource_id,
messages=messages,
frequency_penalty=frequency_penalty,
function_call=function_call,
functions=functions,
logit_bias=logit_bias,
logprobs=logprobs,
max_completion_tokens=max_completion_tokens,
max_tokens=max_tokens,
n=n,
parallel_tool_calls=parallel_tool_calls,
presence_penalty=presence_penalty,
response_format=response_format,
seed=seed,
stop=stop,
stream=stream,
stream_options=stream_options,
temperature=temperature,
tool_choice=tool_choice,
tools=tools,
top_logprobs=top_logprobs,
top_p=top_p,
user=user,
)
return await client.inference.openai_chat_completion(**params)
def cast_value_to_json_dict(self, request_params: Dict[str, Any]) -> Dict[str, Any]:
json_params = {}
for key, value in request_params.items():

View file

@ -12,6 +12,8 @@ from llama_stack.apis.inference import * # noqa: F403
# from llama_stack.providers.datatypes import ModelsProtocolPrivate
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
from llama_stack.providers.utils.inference.openai_compat import (
OpenAIChatCompletionToLlamaStackMixin,
OpenAICompletionToLlamaStackMixin,
get_sampling_options,
process_chat_completion_response,
process_chat_completion_stream_response,
@ -38,7 +40,12 @@ RUNPOD_SUPPORTED_MODELS = {
}
class RunpodInferenceAdapter(ModelRegistryHelper, Inference):
class RunpodInferenceAdapter(
ModelRegistryHelper,
Inference,
OpenAIChatCompletionToLlamaStackMixin,
OpenAICompletionToLlamaStackMixin,
):
def __init__(self, config: RunpodImplConfig) -> None:
ModelRegistryHelper.__init__(self, stack_to_provider_models_map=RUNPOD_SUPPORTED_MODELS)
self.config = config

View file

@ -42,6 +42,8 @@ from llama_stack.apis.inference import (
)
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
from llama_stack.providers.utils.inference.openai_compat import (
OpenAIChatCompletionToLlamaStackMixin,
OpenAICompletionToLlamaStackMixin,
process_chat_completion_stream_response,
)
from llama_stack.providers.utils.inference.prompt_adapter import (
@ -52,7 +54,12 @@ from .config import SambaNovaImplConfig
from .models import MODEL_ENTRIES
class SambaNovaInferenceAdapter(ModelRegistryHelper, Inference):
class SambaNovaInferenceAdapter(
ModelRegistryHelper,
Inference,
OpenAIChatCompletionToLlamaStackMixin,
OpenAICompletionToLlamaStackMixin,
):
def __init__(self, config: SambaNovaImplConfig) -> None:
ModelRegistryHelper.__init__(self, model_entries=MODEL_ENTRIES)
self.config = config

View file

@ -40,8 +40,10 @@ from llama_stack.providers.utils.inference.model_registry import (
build_hf_repo_model_entry,
)
from llama_stack.providers.utils.inference.openai_compat import (
OpenAIChatCompletionToLlamaStackMixin,
OpenAICompatCompletionChoice,
OpenAICompatCompletionResponse,
OpenAICompletionToLlamaStackMixin,
get_sampling_options,
process_chat_completion_response,
process_chat_completion_stream_response,
@ -69,7 +71,12 @@ def build_hf_repo_model_entries():
]
class _HfAdapter(Inference, ModelsProtocolPrivate):
class _HfAdapter(
Inference,
OpenAIChatCompletionToLlamaStackMixin,
OpenAICompletionToLlamaStackMixin,
ModelsProtocolPrivate,
):
client: AsyncInferenceClient
max_tokens: int
model_id: str

View file

@ -4,8 +4,9 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import AsyncGenerator, List, Optional, Union
from typing import Any, AsyncGenerator, AsyncIterator, Dict, List, Optional, Union
from openai import AsyncOpenAI
from together import AsyncTogether
from llama_stack.apis.common.content_types import (
@ -30,12 +31,20 @@ from llama_stack.apis.inference import (
ToolDefinition,
ToolPromptFormat,
)
from llama_stack.apis.inference.inference import (
OpenAIChatCompletion,
OpenAIChatCompletionChunk,
OpenAICompletion,
OpenAIMessageParam,
OpenAIResponseFormatParam,
)
from llama_stack.distribution.request_headers import NeedsRequestProviderData
from llama_stack.log import get_logger
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
from llama_stack.providers.utils.inference.openai_compat import (
convert_message_to_openai_dict,
get_sampling_options,
prepare_openai_completion_params,
process_chat_completion_response,
process_chat_completion_stream_response,
process_completion_response,
@ -60,14 +69,18 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
ModelRegistryHelper.__init__(self, MODEL_ENTRIES)
self.config = config
self._client = None
self._openai_client = None
async def initialize(self) -> None:
pass
async def shutdown(self) -> None:
if self._client:
await self._client.close()
# Together client has no close method, so just set to None
self._client = None
if self._openai_client:
await self._openai_client.close()
self._openai_client = None
async def completion(
self,
@ -110,6 +123,15 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
self._client = AsyncTogether(api_key=together_api_key)
return self._client
def _get_openai_client(self) -> AsyncOpenAI:
if not self._openai_client:
together_client = self._get_client().client
self._openai_client = AsyncOpenAI(
base_url=together_client.base_url,
api_key=together_client.api_key,
)
return self._openai_client
async def _nonstream_completion(self, request: CompletionRequest) -> ChatCompletionResponse:
params = await self._get_params(request)
client = self._get_client()
@ -243,3 +265,123 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
)
embeddings = [item.embedding for item in r.data]
return EmbeddingsResponse(embeddings=embeddings)
async def openai_completion(
self,
model: str,
prompt: Union[str, List[str], List[int], List[List[int]]],
best_of: Optional[int] = None,
echo: Optional[bool] = None,
frequency_penalty: Optional[float] = None,
logit_bias: Optional[Dict[str, float]] = None,
logprobs: Optional[bool] = None,
max_tokens: Optional[int] = None,
n: Optional[int] = None,
presence_penalty: Optional[float] = None,
seed: Optional[int] = None,
stop: Optional[Union[str, List[str]]] = None,
stream: Optional[bool] = None,
stream_options: Optional[Dict[str, Any]] = None,
temperature: Optional[float] = None,
top_p: Optional[float] = None,
user: Optional[str] = None,
guided_choice: Optional[List[str]] = None,
prompt_logprobs: Optional[int] = None,
) -> OpenAICompletion:
model_obj = await self.model_store.get_model(model)
params = await prepare_openai_completion_params(
model=model_obj.provider_resource_id,
prompt=prompt,
best_of=best_of,
echo=echo,
frequency_penalty=frequency_penalty,
logit_bias=logit_bias,
logprobs=logprobs,
max_tokens=max_tokens,
n=n,
presence_penalty=presence_penalty,
seed=seed,
stop=stop,
stream=stream,
stream_options=stream_options,
temperature=temperature,
top_p=top_p,
user=user,
)
return await self._get_openai_client().completions.create(**params) # type: ignore
async def openai_chat_completion(
self,
model: str,
messages: List[OpenAIMessageParam],
frequency_penalty: Optional[float] = None,
function_call: Optional[Union[str, Dict[str, Any]]] = None,
functions: Optional[List[Dict[str, Any]]] = None,
logit_bias: Optional[Dict[str, float]] = None,
logprobs: Optional[bool] = None,
max_completion_tokens: Optional[int] = None,
max_tokens: Optional[int] = None,
n: Optional[int] = None,
parallel_tool_calls: Optional[bool] = None,
presence_penalty: Optional[float] = None,
response_format: Optional[OpenAIResponseFormatParam] = None,
seed: Optional[int] = None,
stop: Optional[Union[str, List[str]]] = None,
stream: Optional[bool] = None,
stream_options: Optional[Dict[str, Any]] = None,
temperature: Optional[float] = None,
tool_choice: Optional[Union[str, Dict[str, Any]]] = None,
tools: Optional[List[Dict[str, Any]]] = None,
top_logprobs: Optional[int] = None,
top_p: Optional[float] = None,
user: Optional[str] = None,
) -> Union[OpenAIChatCompletion, AsyncIterator[OpenAIChatCompletionChunk]]:
model_obj = await self.model_store.get_model(model)
params = await prepare_openai_completion_params(
model=model_obj.provider_resource_id,
messages=messages,
frequency_penalty=frequency_penalty,
function_call=function_call,
functions=functions,
logit_bias=logit_bias,
logprobs=logprobs,
max_completion_tokens=max_completion_tokens,
max_tokens=max_tokens,
n=n,
parallel_tool_calls=parallel_tool_calls,
presence_penalty=presence_penalty,
response_format=response_format,
seed=seed,
stop=stop,
stream=stream,
stream_options=stream_options,
temperature=temperature,
tool_choice=tool_choice,
tools=tools,
top_logprobs=top_logprobs,
top_p=top_p,
user=user,
)
if params.get("stream", False):
return self._stream_openai_chat_completion(params)
return await self._get_openai_client().chat.completions.create(**params) # type: ignore
async def _stream_openai_chat_completion(self, params: dict) -> AsyncGenerator:
# together.ai sometimes adds usage data to the stream, even if include_usage is False
# This causes an unexpected final chunk with empty choices array to be sent
# to clients that may not handle it gracefully.
include_usage = False
if params.get("stream_options", None):
include_usage = params["stream_options"].get("include_usage", False)
stream = await self._get_openai_client().chat.completions.create(**params)
seen_finish_reason = False
async for chunk in stream:
# Final usage chunk with no choices that the user didn't request, so discard
if not include_usage and seen_finish_reason and len(chunk.choices) == 0:
break
yield chunk
for choice in chunk.choices:
if choice.finish_reason:
seen_finish_reason = True
break

View file

@ -5,7 +5,7 @@
# the root directory of this source tree.
import json
import logging
from typing import Any, AsyncGenerator, List, Optional, Union
from typing import Any, AsyncGenerator, AsyncIterator, Dict, List, Optional, Union
import httpx
from openai import AsyncOpenAI
@ -45,6 +45,12 @@ from llama_stack.apis.inference import (
ToolDefinition,
ToolPromptFormat,
)
from llama_stack.apis.inference.inference import (
OpenAIChatCompletion,
OpenAICompletion,
OpenAIMessageParam,
OpenAIResponseFormatParam,
)
from llama_stack.apis.models import Model, ModelType
from llama_stack.models.llama.datatypes import BuiltinTool, StopReason, ToolCall
from llama_stack.models.llama.sku_list import all_registered_models
@ -58,6 +64,7 @@ from llama_stack.providers.utils.inference.openai_compat import (
convert_message_to_openai_dict,
convert_tool_call,
get_sampling_options,
prepare_openai_completion_params,
process_chat_completion_stream_response,
process_completion_response,
process_completion_stream_response,
@ -224,12 +231,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
self.client = None
async def initialize(self) -> None:
log.info(f"Initializing VLLM client with base_url={self.config.url}")
self.client = AsyncOpenAI(
base_url=self.config.url,
api_key=self.config.api_token,
http_client=None if self.config.tls_verify else httpx.AsyncClient(verify=False),
)
pass
async def shutdown(self) -> None:
pass
@ -242,6 +244,20 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
raise ValueError("Model store not set")
return await self.model_store.get_model(model_id)
def _lazy_initialize_client(self):
if self.client is not None:
return
log.info(f"Initializing vLLM client with base_url={self.config.url}")
self.client = self._create_client()
def _create_client(self):
return AsyncOpenAI(
base_url=self.config.url,
api_key=self.config.api_token,
http_client=None if self.config.tls_verify else httpx.AsyncClient(verify=False),
)
async def completion(
self,
model_id: str,
@ -251,6 +267,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> CompletionResponse | AsyncGenerator[CompletionResponseStreamChunk, None]:
self._lazy_initialize_client()
if sampling_params is None:
sampling_params = SamplingParams()
model = await self._get_model(model_id)
@ -280,6 +297,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
logprobs: Optional[LogProbConfig] = None,
tool_config: Optional[ToolConfig] = None,
) -> ChatCompletionResponse | AsyncGenerator[ChatCompletionResponseStreamChunk, None]:
self._lazy_initialize_client()
if sampling_params is None:
sampling_params = SamplingParams()
model = await self._get_model(model_id)
@ -350,9 +368,12 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
yield chunk
async def register_model(self, model: Model) -> Model:
assert self.client is not None
# register_model is called during Llama Stack initialization, hence we cannot init self.client if not initialized yet.
# self.client should only be created after the initialization is complete to avoid asyncio cross-context errors.
# Changing this may lead to unpredictable behavior.
client = self._create_client() if self.client is None else self.client
model = await self.register_helper.register_model(model)
res = await self.client.models.list()
res = await client.models.list()
available_models = [m.id async for m in res]
if model.provider_resource_id not in available_models:
raise ValueError(
@ -367,7 +388,8 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
options["max_tokens"] = self.config.max_tokens
input_dict: dict[str, Any] = {}
if isinstance(request, ChatCompletionRequest) and request.tools is not None:
# Only include the 'tools' param if there is any. It can break things if an empty list is sent to the vLLM.
if isinstance(request, ChatCompletionRequest) and request.tools:
input_dict = {"tools": _convert_to_vllm_tools_in_request(request.tools)}
if isinstance(request, ChatCompletionRequest):
@ -402,6 +424,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
output_dimension: Optional[int] = None,
task_type: Optional[EmbeddingTaskType] = None,
) -> EmbeddingsResponse:
self._lazy_initialize_client()
assert self.client is not None
model = await self._get_model(model_id)
@ -418,3 +441,133 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
embeddings = [data.embedding for data in response.data]
return EmbeddingsResponse(embeddings=embeddings)
async def openai_completion(
self,
model: str,
prompt: Union[str, List[str], List[int], List[List[int]]],
best_of: Optional[int] = None,
echo: Optional[bool] = None,
frequency_penalty: Optional[float] = None,
logit_bias: Optional[Dict[str, float]] = None,
logprobs: Optional[bool] = None,
max_tokens: Optional[int] = None,
n: Optional[int] = None,
presence_penalty: Optional[float] = None,
seed: Optional[int] = None,
stop: Optional[Union[str, List[str]]] = None,
stream: Optional[bool] = None,
stream_options: Optional[Dict[str, Any]] = None,
temperature: Optional[float] = None,
top_p: Optional[float] = None,
user: Optional[str] = None,
guided_choice: Optional[List[str]] = None,
prompt_logprobs: Optional[int] = None,
) -> OpenAICompletion:
self._lazy_initialize_client()
model_obj = await self._get_model(model)
extra_body: Dict[str, Any] = {}
if prompt_logprobs is not None and prompt_logprobs >= 0:
extra_body["prompt_logprobs"] = prompt_logprobs
if guided_choice:
extra_body["guided_choice"] = guided_choice
params = await prepare_openai_completion_params(
model=model_obj.provider_resource_id,
prompt=prompt,
best_of=best_of,
echo=echo,
frequency_penalty=frequency_penalty,
logit_bias=logit_bias,
logprobs=logprobs,
max_tokens=max_tokens,
n=n,
presence_penalty=presence_penalty,
seed=seed,
stop=stop,
stream=stream,
stream_options=stream_options,
temperature=temperature,
top_p=top_p,
user=user,
extra_body=extra_body,
)
return await self.client.completions.create(**params) # type: ignore
async def openai_chat_completion(
self,
model: str,
messages: List[OpenAIMessageParam],
frequency_penalty: Optional[float] = None,
function_call: Optional[Union[str, Dict[str, Any]]] = None,
functions: Optional[List[Dict[str, Any]]] = None,
logit_bias: Optional[Dict[str, float]] = None,
logprobs: Optional[bool] = None,
max_completion_tokens: Optional[int] = None,
max_tokens: Optional[int] = None,
n: Optional[int] = None,
parallel_tool_calls: Optional[bool] = None,
presence_penalty: Optional[float] = None,
response_format: Optional[OpenAIResponseFormatParam] = None,
seed: Optional[int] = None,
stop: Optional[Union[str, List[str]]] = None,
stream: Optional[bool] = None,
stream_options: Optional[Dict[str, Any]] = None,
temperature: Optional[float] = None,
tool_choice: Optional[Union[str, Dict[str, Any]]] = None,
tools: Optional[List[Dict[str, Any]]] = None,
top_logprobs: Optional[int] = None,
top_p: Optional[float] = None,
user: Optional[str] = None,
) -> Union[OpenAIChatCompletion, AsyncIterator[OpenAIChatCompletionChunk]]:
self._lazy_initialize_client()
model_obj = await self._get_model(model)
params = await prepare_openai_completion_params(
model=model_obj.provider_resource_id,
messages=messages,
frequency_penalty=frequency_penalty,
function_call=function_call,
functions=functions,
logit_bias=logit_bias,
logprobs=logprobs,
max_completion_tokens=max_completion_tokens,
max_tokens=max_tokens,
n=n,
parallel_tool_calls=parallel_tool_calls,
presence_penalty=presence_penalty,
response_format=response_format,
seed=seed,
stop=stop,
stream=stream,
stream_options=stream_options,
temperature=temperature,
tool_choice=tool_choice,
tools=tools,
top_logprobs=top_logprobs,
top_p=top_p,
user=user,
)
return await self.client.chat.completions.create(**params) # type: ignore
async def batch_completion(
self,
model_id: str,
content_batch: List[InterleavedContent],
sampling_params: Optional[SamplingParams] = None,
response_format: Optional[ResponseFormat] = None,
logprobs: Optional[LogProbConfig] = None,
):
raise NotImplementedError("Batch completion is not supported for Ollama")
async def batch_chat_completion(
self,
model_id: str,
messages_batch: List[List[Message]],
sampling_params: Optional[SamplingParams] = None,
tools: Optional[List[ToolDefinition]] = None,
tool_config: Optional[ToolConfig] = None,
response_format: Optional[ResponseFormat] = None,
logprobs: Optional[LogProbConfig] = None,
):
raise NotImplementedError("Batch chat completion is not supported for Ollama")

View file

@ -16,7 +16,11 @@ _MODEL_ENTRIES = [
build_hf_repo_model_entry(
"meta/llama-3.1-8b-instruct",
CoreModelId.llama3_1_8b_instruct.value,
)
),
build_hf_repo_model_entry(
"meta/llama-3.2-1b-instruct",
CoreModelId.llama3_2_1b_instruct.value,
),
]

View file

@ -27,11 +27,12 @@ from .models import _MODEL_ENTRIES
# Map API status to JobStatus enum
STATUS_MAPPING = {
"running": "in_progress",
"completed": "completed",
"failed": "failed",
"cancelled": "cancelled",
"pending": "scheduled",
"running": JobStatus.in_progress.value,
"completed": JobStatus.completed.value,
"failed": JobStatus.failed.value,
"cancelled": JobStatus.cancelled.value,
"pending": JobStatus.scheduled.value,
"unknown": JobStatus.scheduled.value,
}
@ -206,10 +207,6 @@ class NvidiaPostTrainingAdapter(ModelRegistryHelper):
model: str,
checkpoint_dir: Optional[str],
algorithm_config: Optional[AlgorithmConfig] = None,
extra_json: Optional[Dict[str, Any]] = None,
params: Optional[Dict[str, Any]] = None,
headers: Optional[Dict[str, Any]] = None,
**kwargs,
) -> NvidiaPostTrainingJob:
"""
Fine-tunes a model on a dataset.

View file

@ -0,0 +1,77 @@
# NVIDIA Safety Provider for LlamaStack
This provider enables safety checks and guardrails for LLM interactions using NVIDIA's NeMo Guardrails service.
## Features
- Run safety checks for messages
## Getting Started
### Prerequisites
- LlamaStack with NVIDIA configuration
- Access to NVIDIA NeMo Guardrails service
- NIM for model to use for safety check is deployed
### Setup
Build the NVIDIA environment:
```bash
llama stack build --template nvidia --image-type conda
```
### Basic Usage using the LlamaStack Python Client
#### Initialize the client
```python
import os
os.environ["NVIDIA_API_KEY"] = "your-api-key"
os.environ["NVIDIA_GUARDRAILS_URL"] = "http://guardrails.test"
from llama_stack.distribution.library_client import LlamaStackAsLibraryClient
client = LlamaStackAsLibraryClient("nvidia")
client.initialize()
```
#### Create a safety shield
```python
from llama_stack.apis.safety import Shield
from llama_stack.apis.inference import Message
# Create a safety shield
shield = Shield(
shield_id="your-shield-id",
provider_resource_id="safety-model-id", # The model to use for safety checks
description="Safety checks for content moderation",
)
# Register the shield
await client.safety.register_shield(shield)
```
#### Run safety checks
```python
# Messages to check
messages = [Message(role="user", content="Your message to check")]
# Run safety check
response = await client.safety.run_shield(
shield_id="your-shield-id",
messages=messages,
)
# Check for violations
if response.violation:
print(f"Safety violation detected: {response.violation.user_message}")
print(f"Violation level: {response.violation.violation_level}")
print(f"Metadata: {response.violation.metadata}")
else:
print("No safety violations detected")
```

Some files were not shown because too many files have changed in this diff Show more