Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-08-02 08:44:44 +00:00)

Merge branch 'main' of https://github.com/meta-llama/llama-stack into register_custom_model

Commit 0990f60dad: 74 changed files with 4854 additions and 1869 deletions
.github/workflows/integration-tests.yml (1 change)

@@ -6,7 +6,6 @@ on:
   pull_request:
-    branches: [ main ]
     paths:
       - 'distributions/**'
       - 'llama_stack/**'
       - 'tests/integration/**'
       - 'uv.lock'
.github/workflows/providers-build.yml (44 changes)

@@ -86,15 +86,15 @@ jobs:
    runs-on: ubuntu-latest
    steps:
      - name: Checkout repository
-        uses: actions/checkout@v4
+        uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2

      - name: Set up Python
-        uses: actions/setup-python@v5
+        uses: actions/setup-python@8d9ed9ac5c53483de85588cdf95a591a75ab9f55 # v5.5.0
        with:
          python-version: '3.10'

      - name: Install uv
-        uses: astral-sh/setup-uv@v5
+        uses: astral-sh/setup-uv@0c5e2b8115b80b4c7c5ddf6ffdd634974642d182 # v5.4.1
        with:
          python-version: "3.10"

@@ -107,3 +107,41 @@ jobs:
      - name: Build a single provider
        run: |
          USE_COPY_NOT_MOUNT=true LLAMA_STACK_DIR=. uv run llama stack build --image-type venv --image-name test --providers inference=remote::ollama

  build-custom-container-distribution:
    runs-on: ubuntu-latest
    steps:
      - name: Checkout repository
        uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2

      - name: Set up Python
        uses: actions/setup-python@8d9ed9ac5c53483de85588cdf95a591a75ab9f55 # v5.5.0
        with:
          python-version: '3.10'

      - name: Install uv
        uses: astral-sh/setup-uv@0c5e2b8115b80b4c7c5ddf6ffdd634974642d182 # v5.4.1
        with:
          python-version: "3.10"

      - name: Install LlamaStack
        run: |
          uv venv
          source .venv/bin/activate
          uv pip install -e .

      - name: Build a single provider
        run: |
          yq -i '.image_type = "container"' llama_stack/templates/dev/build.yaml
          yq -i '.image_name = "test"' llama_stack/templates/dev/build.yaml
          USE_COPY_NOT_MOUNT=true LLAMA_STACK_DIR=. uv run llama stack build --config llama_stack/templates/dev/build.yaml

      - name: Inspect the container image entrypoint
        run: |
          IMAGE_ID=$(docker images --format "{{.Repository}}:{{.Tag}}" | head -n 1)
          entrypoint=$(docker inspect --format '{{ .Config.Entrypoint }}' $IMAGE_ID)
          echo "Entrypoint: $entrypoint"
          if [ "$entrypoint" != "[python -m llama_stack.distribution.server.server --config /app/run.yaml]" ]; then
            echo "Entrypoint is not correct"
            exit 1
          fi
.github/workflows/test-external-providers.yml (30 changes)

@@ -5,10 +5,22 @@ on:
    branches: [ main ]
  pull_request:
    branches: [ main ]
    paths:
      - 'llama_stack/**'
      - 'tests/integration/**'
      - 'uv.lock'
      - 'pyproject.toml'
      - 'requirements.txt'
      - '.github/workflows/test-external-providers.yml' # This workflow

jobs:
  test-external-providers:
    runs-on: ubuntu-latest
    strategy:
      matrix:
        image-type: [venv]
        # We don't do container yet, it's tricky to install a package from the host into the
        # container and point 'uv pip install' to the correct path...
    steps:
      - name: Checkout repository
        uses: actions/checkout@v4

@@ -35,17 +47,25 @@ jobs:
          uv sync --extra dev --extra test
          uv pip install -e .

-      - name: Install Ollama custom provider
+      - name: Apply image type to config file
        run: |
          yq -i '.image_type = "${{ matrix.image-type }}"' tests/external-provider/llama-stack-provider-ollama/custom-distro.yaml
          cat tests/external-provider/llama-stack-provider-ollama/custom-distro.yaml

      - name: Setup directory for Ollama custom provider
        run: |
          mkdir -p tests/external-provider/llama-stack-provider-ollama/src/
          cp -a llama_stack/providers/remote/inference/ollama/ tests/external-provider/llama-stack-provider-ollama/src/llama_stack_provider_ollama
          uv pip install tests/external-provider/llama-stack-provider-ollama

      - name: Create provider configuration
        run: |
          mkdir -p /tmp/providers.d/remote/inference
          cp tests/external-provider/llama-stack-provider-ollama/custom_ollama.yaml /tmp/providers.d/remote/inference/custom_ollama.yaml

      - name: Build distro from config file
        run: |
          USE_COPY_NOT_MOUNT=true LLAMA_STACK_DIR=. uv run llama stack build --config tests/external-provider/llama-stack-provider-ollama/custom-distro.yaml

      - name: Wait for Ollama to start
        run: |
          echo "Waiting for Ollama..."

@@ -62,11 +82,13 @@ jobs:
          exit 1

      - name: Start Llama Stack server in background
        if: ${{ matrix.image-type }} == 'venv'
        env:
          INFERENCE_MODEL: "meta-llama/Llama-3.2-3B-Instruct"
        run: |
-          source .venv/bin/activate
-          nohup uv run llama stack run tests/external-provider/llama-stack-provider-ollama/run.yaml --image-type venv > server.log 2>&1 &
+          source ci-test/bin/activate
+          uv run pip list
+          nohup uv run --active llama stack run tests/external-provider/llama-stack-provider-ollama/run.yaml --image-type ${{ matrix.image-type }} > server.log 2>&1 &

      - name: Wait for Llama Stack server to be ready
        run: |
.github/workflows/unit-tests.yml (1 change)

@@ -6,7 +6,6 @@ on:
   pull_request:
-    branches: [ main ]
     paths:
       - 'distributions/**'
       - 'llama_stack/**'
       - 'tests/unit/**'
       - 'uv.lock'
docs/_static/js/detect_theme.js (29 changes)

@@ -1,9 +1,32 @@
document.addEventListener("DOMContentLoaded", function () {
  const prefersDark = window.matchMedia("(prefers-color-scheme: dark)").matches;
  const htmlElement = document.documentElement;
-  if (prefersDark) {
-    htmlElement.setAttribute("data-theme", "dark");
+
+  // Check if theme is saved in localStorage
+  const savedTheme = localStorage.getItem("sphinx-rtd-theme");
+
+  if (savedTheme) {
+    // Use the saved theme preference
+    htmlElement.setAttribute("data-theme", savedTheme);
+    document.body.classList.toggle("dark", savedTheme === "dark");
  } else {
-    htmlElement.setAttribute("data-theme", "light");
+    // Fall back to system preference
+    const theme = prefersDark ? "dark" : "light";
+    htmlElement.setAttribute("data-theme", theme);
+    document.body.classList.toggle("dark", theme === "dark");
+    // Save initial preference
+    localStorage.setItem("sphinx-rtd-theme", theme);
  }

+  // Listen for theme changes from the existing toggle
+  const observer = new MutationObserver(function(mutations) {
+    mutations.forEach(function(mutation) {
+      if (mutation.attributeName === "data-theme") {
+        const currentTheme = htmlElement.getAttribute("data-theme");
+        localStorage.setItem("sphinx-rtd-theme", currentTheme);
+      }
+    });
+  });
+
+  observer.observe(htmlElement, { attributes: true });
});
docs/_static/llama-stack-spec.html (22 changes)

@@ -5221,17 +5221,25 @@
          "default": 10
        },
        "model": {
-          "type": "string"
+          "type": "string",
+          "description": "The model identifier to use for the agent"
        },
        "instructions": {
-          "type": "string"
+          "type": "string",
+          "description": "The system instructions for the agent"
        },
+        "name": {
+          "type": "string",
+          "description": "Optional name for the agent, used in telemetry and identification"
+        },
        "enable_session_persistence": {
          "type": "boolean",
-          "default": false
+          "default": false,
+          "description": "Optional flag indicating whether session data has to be persisted"
        },
        "response_format": {
-          "$ref": "#/components/schemas/ResponseFormat"
+          "$ref": "#/components/schemas/ResponseFormat",
+          "description": "Optional response format configuration"
        }
      },
      "additionalProperties": false,

@@ -5239,7 +5247,8 @@
        "model",
        "instructions"
      ],
-      "title": "AgentConfig"
+      "title": "AgentConfig",
+      "description": "Configuration for an agent."
    },
    "AgentTool": {
      "oneOf": [

@@ -8891,8 +8900,7 @@
      },
      "additionalProperties": false,
      "required": [
        "role",
-        "content"
+        "role"
      ],
      "title": "OpenAIAssistantMessageParam",
      "description": "A message containing the model's (assistant) response in an OpenAI-compatible chat completion request."
docs/_static/llama-stack-spec.yaml (12 changes)

@@ -3686,18 +3686,29 @@ components:
        default: 10
      model:
        type: string
+        description: >-
+          The model identifier to use for the agent
      instructions:
        type: string
+        description: The system instructions for the agent
+      name:
+        type: string
+        description: >-
+          Optional name for the agent, used in telemetry and identification
      enable_session_persistence:
        type: boolean
        default: false
+        description: >-
+          Optional flag indicating whether session data has to be persisted
      response_format:
        $ref: '#/components/schemas/ResponseFormat'
+        description: Optional response format configuration
    additionalProperties: false
    required:
      - model
      - instructions
    title: AgentConfig
+    description: Configuration for an agent.
  AgentTool:
    oneOf:
      - type: string

@@ -6097,7 +6108,6 @@ components:
    additionalProperties: false
    required:
      - role
-      - content
    title: OpenAIAssistantMessageParam
    description: >-
      A message containing the model's (assistant) response in an OpenAI-compatible
(RAG documentation)

@@ -68,7 +68,8 @@ chunks_response = client.vector_io.query(
### Using the RAG Tool

A better way to ingest documents is to use the RAG Tool. This tool allows you to ingest documents from URLs, files, etc.
-and automatically chunks them into smaller pieces.
+and automatically chunks them into smaller pieces. More examples for how to format a RAGDocument can be found in the
+[appendix](#more-ragdocument-examples).

```python
from llama_stack_client import RAGDocument

@@ -178,3 +179,38 @@ for vector_db_id in client.vector_dbs.list():
    print(f"Unregistering vector database: {vector_db_id.identifier}")
    client.vector_dbs.unregister(vector_db_id=vector_db_id.identifier)
```

### Appendix

#### More RAGDocument Examples
```python
import base64

import requests
from llama_stack_client import RAGDocument

RAGDocument(document_id="num-0", content={"uri": "file://path/to/file"})
RAGDocument(document_id="num-1", content="plain text")
RAGDocument(
    document_id="num-2",
    content={
        "type": "text",
        "text": "plain text input",
    },  # for inputs that should be treated as text explicitly
)
RAGDocument(
    document_id="num-3",
    content={
        "type": "image",
        "image": {"url": {"uri": "https://mywebsite.com/image.jpg"}},
    },
)
B64_ENCODED_IMAGE = base64.b64encode(
    requests.get(
        "https://raw.githubusercontent.com/meta-llama/llama-stack/refs/heads/main/docs/_static/llama-stack.png"
    ).content
)
RAGDocument(
    document_id="num-4",
    content={"type": "image", "image": {"data": B64_ENCODED_IMAGE}},
)
```
For more strongly typed interaction use the typed dicts found [here](https://github.com/meta-llama/llama-stack-client-python/blob/38cd91c9e396f2be0bec1ee96a19771582ba6f17/src/llama_stack_client/types/shared_params/document.py).
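The diff above only shows the start of the ingestion snippet. For context, a minimal sketch of how documents like these are typically handed to the RAG tool, assuming the `client.tool_runtime.rag_tool.insert` API from the surrounding documentation and a vector database that has already been registered (the id below is a placeholder, not something this diff defines):

```python
from llama_stack_client import LlamaStackClient, RAGDocument

client = LlamaStackClient(base_url="http://localhost:8321")

# Hypothetical vector DB id; register one first via client.vector_dbs.register(...)
vector_db_id = "my_documents"

documents = [
    RAGDocument(document_id="num-0", content={"uri": "file://path/to/file"}),
    RAGDocument(document_id="num-1", content="plain text"),
]

# The RAG tool chunks the documents and writes them into the vector database
client.tool_runtime.rag_tool.insert(
    documents=documents,
    vector_db_id=vector_db_id,
    chunk_size_in_tokens=512,
)
```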
(Tools documentation)

@@ -41,7 +41,7 @@ client.toolgroups.register(

The tool requires an API key which can be provided either in the configuration or through the request header `X-LlamaStack-Provider-Data`. The format of the header is `{"<provider_name>_api_key": <your api key>}`.

> **NOTE:** When using Tavily Search and Bing Search, the inference output will still display "Brave Search." This is because Llama models have been trained with Brave Search as a built-in tool. Tavily and Bing are just being used in lieu of Brave Search.

#### Code Interpreter

@@ -214,3 +214,69 @@ response = agent.create_turn(
    session_id=session_id,
)
```
## Simple Example 2: Using an Agent with the Web Search Tool
1. Start by registering a Tavily API key at [Tavily](https://tavily.com/).
2. [Optional] Provide the API key directly to the Llama Stack server
```bash
export TAVILY_SEARCH_API_KEY="your key"
```
```bash
--env TAVILY_SEARCH_API_KEY=${TAVILY_SEARCH_API_KEY}
```
3. Run the following script.
```python
from llama_stack_client.lib.agents.agent import Agent
from llama_stack_client.types.agent_create_params import AgentConfig
from llama_stack_client.lib.agents.event_logger import EventLogger
from llama_stack_client import LlamaStackClient

client = LlamaStackClient(
    base_url="http://localhost:8321",
    provider_data={
        "tavily_search_api_key": "your_TAVILY_SEARCH_API_KEY"
    },  # Set this from the client side. No need to provide it if it has already been configured on the Llama Stack server.
)

agent = Agent(
    client,
    model="meta-llama/Llama-3.2-3B-Instruct",
    instructions=(
        "You are a web search assistant; you must use the websearch tool to look up the most current and precise information available."
    ),
    tools=["builtin::websearch"],
)

session_id = agent.create_session("websearch-session")

response = agent.create_turn(
    messages=[
        {"role": "user", "content": "How did the USA perform in the last Olympics?"}
    ],
    session_id=session_id,
)
for log in EventLogger().log(response):
    log.print()
```

## Simple Example 3: Using an Agent with the WolframAlpha Tool
1. Start by registering for a WolframAlpha API key at the [WolframAlpha Developer Portal](https://developer.wolframalpha.com/access).
2. Provide the API key either when starting the Llama Stack server:
   ```bash
   --env WOLFRAM_ALPHA_API_KEY=${WOLFRAM_ALPHA_API_KEY}
   ```
   or from the client side:
   ```python
   client = LlamaStackClient(
       base_url="http://localhost:8321",
       provider_data={"wolfram_alpha_api_key": wolfram_api_key},
   )
   ```
3. Configure the tools in the Agent by setting `tools=["builtin::wolfram_alpha"]`.
4. Example user query:
   ```python
   response = agent.create_turn(
       messages=[{"role": "user", "content": "Solve x^2 + 2x + 1 = 0 using WolframAlpha"}],
       session_id=session_id,
   )
   ```
(Building a distribution documentation)

@@ -176,7 +176,11 @@ distribution_spec:
    safety: inline::llama-guard
    agents: inline::meta-reference
    telemetry: inline::meta-reference
image_name: ollama
image_type: conda

# If some providers are external, you can specify the path to the implementation
external_providers_dir: /etc/llama-stack/providers.d
```

```
@@ -184,6 +188,57 @@ llama stack build --config llama_stack/templates/ollama/build.yaml
```
:::

:::{tab-item} Building with External Providers

Llama Stack supports external providers that live outside of the main codebase. This allows you to create and maintain your own providers independently or use community-provided providers.

To build a distribution with external providers, you need to:

1. Configure the `external_providers_dir` in your build configuration file:

```yaml
# Example my-external-stack.yaml with external providers
version: '2'
distribution_spec:
  description: Custom distro for CI tests
  providers:
    inference:
      - remote::custom_ollama
# Add more providers as needed
image_type: container
image_name: ci-test
# Path to external provider implementations
external_providers_dir: /etc/llama-stack/providers.d
```

Here's an example for a custom Ollama provider:

```yaml
adapter:
  adapter_type: custom_ollama
  pip_packages:
    - ollama
    - aiohttp
    - llama-stack-provider-ollama # This is the provider package
  config_class: llama_stack_ollama_provider.config.OllamaImplConfig
  module: llama_stack_ollama_provider
api_dependencies: []
optional_api_dependencies: []
```

The `pip_packages` section lists the Python packages required by the provider, as well as the
provider package itself. The package must be available on PyPI or can be provided from a local
directory or a git repository (git must be installed on the build environment).

2. Build your distribution using the config file:

```
llama stack build --config my-external-stack.yaml
```

For more information on external providers, including directory structure, provider types, and implementation requirements, see the [External Providers documentation](../providers/external.md).
:::

:::{tab-item} Building Container

```{admonition} Podman Alternative
(Deleted file: NVIDIA Distribution documentation, 87 lines removed)

@@ -1,87 +0,0 @@
<!-- This file was auto-generated by distro_codegen.py, please edit source -->
# NVIDIA Distribution

The `llamastack/distribution-nvidia` distribution consists of the following provider configurations.

| API | Provider(s) |
|-----|-------------|
| agents | `inline::meta-reference` |
| datasetio | `inline::localfs` |
| eval | `inline::meta-reference` |
| inference | `remote::nvidia` |
| post_training | `remote::nvidia` |
| safety | `remote::nvidia` |
| scoring | `inline::basic` |
| telemetry | `inline::meta-reference` |
| tool_runtime | `inline::rag-runtime` |
| vector_io | `inline::faiss` |

### Environment Variables

The following environment variables can be configured:

- `NVIDIA_API_KEY`: NVIDIA API Key (default: ``)
- `NVIDIA_APPEND_API_VERSION`: Whether to append the API version to the base_url (default: `True`)
- `NVIDIA_DATASET_NAMESPACE`: NVIDIA Dataset Namespace (default: `default`)
- `NVIDIA_PROJECT_ID`: NVIDIA Project ID (default: `test-project`)
- `NVIDIA_CUSTOMIZER_URL`: NVIDIA Customizer URL (default: `https://customizer.api.nvidia.com`)
- `NVIDIA_OUTPUT_MODEL_DIR`: NVIDIA Output Model Directory (default: `test-example-model@v1`)
- `GUARDRAILS_SERVICE_URL`: URL for the NeMo Guardrails Service (default: `http://0.0.0.0:7331`)
- `INFERENCE_MODEL`: Inference model (default: `Llama3.1-8B-Instruct`)
- `SAFETY_MODEL`: Name of the model to use for safety (default: `meta/llama-3.1-8b-instruct`)

### Models

The following models are available by default:

- `meta/llama3-8b-instruct (aliases: meta-llama/Llama-3-8B-Instruct)`
- `meta/llama3-70b-instruct (aliases: meta-llama/Llama-3-70B-Instruct)`
- `meta/llama-3.1-8b-instruct (aliases: meta-llama/Llama-3.1-8B-Instruct)`
- `meta/llama-3.1-70b-instruct (aliases: meta-llama/Llama-3.1-70B-Instruct)`
- `meta/llama-3.1-405b-instruct (aliases: meta-llama/Llama-3.1-405B-Instruct-FP8)`
- `meta/llama-3.2-1b-instruct (aliases: meta-llama/Llama-3.2-1B-Instruct)`
- `meta/llama-3.2-3b-instruct (aliases: meta-llama/Llama-3.2-3B-Instruct)`
- `meta/llama-3.2-11b-vision-instruct (aliases: meta-llama/Llama-3.2-11B-Vision-Instruct)`
- `meta/llama-3.2-90b-vision-instruct (aliases: meta-llama/Llama-3.2-90B-Vision-Instruct)`
- `nvidia/llama-3.2-nv-embedqa-1b-v2`
- `nvidia/nv-embedqa-e5-v5`
- `nvidia/nv-embedqa-mistral-7b-v2`
- `snowflake/arctic-embed-l`

### Prerequisite: API Keys

Make sure you have access to a NVIDIA API Key. You can get one by visiting [https://build.nvidia.com/](https://build.nvidia.com/).

## Running Llama Stack with NVIDIA

You can do this via Conda (build code) or Docker which has a pre-built image.

### Via Docker

This method allows you to get started quickly without having to build the distribution code.

```bash
LLAMA_STACK_PORT=8321
docker run \
  -it \
  --pull always \
  -p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
  -v ./run.yaml:/root/my-run.yaml \
  llamastack/distribution-nvidia \
  --yaml-config /root/my-run.yaml \
  --port $LLAMA_STACK_PORT \
  --env NVIDIA_API_KEY=$NVIDIA_API_KEY
```

### Via Conda

```bash
llama stack build --template nvidia --image-type conda
llama stack run ./run.yaml \
  --port 8321 \
  --env NVIDIA_API_KEY=$NVIDIA_API_KEY
  --env INFERENCE_MODEL=$INFERENCE_MODEL
```
(NVIDIA distribution documentation, updated)

@@ -22,6 +22,7 @@ The `llamastack/distribution-nvidia` distribution consists of the following prov
The following environment variables can be configured:

- `NVIDIA_API_KEY`: NVIDIA API Key (default: ``)
- `NVIDIA_APPEND_API_VERSION`: Whether to append the API version to the base_url (default: `True`)
- `NVIDIA_DATASET_NAMESPACE`: NVIDIA Dataset Namespace (default: `default`)
- `NVIDIA_PROJECT_ID`: NVIDIA Project ID (default: `test-project`)
- `NVIDIA_CUSTOMIZER_URL`: NVIDIA Customizer URL (default: `https://customizer.api.nvidia.com`)

@@ -43,20 +44,91 @@ The following models are available by default:
- `meta/llama-3.2-3b-instruct (aliases: meta-llama/Llama-3.2-3B-Instruct)`
- `meta/llama-3.2-11b-vision-instruct (aliases: meta-llama/Llama-3.2-11B-Vision-Instruct)`
- `meta/llama-3.2-90b-vision-instruct (aliases: meta-llama/Llama-3.2-90B-Vision-Instruct)`
- `meta/llama-3.3-70b-instruct (aliases: meta-llama/Llama-3.3-70B-Instruct)`
- `nvidia/llama-3.2-nv-embedqa-1b-v2`
- `nvidia/nv-embedqa-e5-v5`
- `nvidia/nv-embedqa-mistral-7b-v2`
- `snowflake/arctic-embed-l`

-### Prerequisite: API Keys
+## Prerequisites
+### NVIDIA API Keys

-Make sure you have access to a NVIDIA API Key. You can get one by visiting [https://build.nvidia.com/](https://build.nvidia.com/).
+Make sure you have access to a NVIDIA API Key. You can get one by visiting [https://build.nvidia.com/](https://build.nvidia.com/). Use this key for the `NVIDIA_API_KEY` environment variable.

### Deploy NeMo Microservices Platform
The NVIDIA NeMo microservices platform supports end-to-end microservice deployment of a complete AI flywheel on your Kubernetes cluster through the NeMo Microservices Helm Chart. Please reference the [NVIDIA NeMo Microservices documentation](https://docs.nvidia.com/nemo/microservices/latest/about/index.html) for platform prerequisites and instructions to install and deploy the platform.

## Supported Services
Each Llama Stack API corresponds to a specific NeMo microservice. The core microservices (Customizer, Evaluator, Guardrails) are exposed by the same endpoint. The platform components (Data Store) are each exposed by separate endpoints.

### Inference: NVIDIA NIM
NVIDIA NIM is used for running inference with registered models. There are two ways to access NVIDIA NIMs:
1. Hosted (default): Preview APIs hosted at https://integrate.api.nvidia.com (requires an API key)
2. Self-hosted: NVIDIA NIMs that run on your own infrastructure.

The deployed platform includes the NIM Proxy microservice, which is the service that provides access to your NIMs (for example, to run inference on a model). Set the `NVIDIA_BASE_URL` environment variable to use your NVIDIA NIM Proxy deployment.

### Datasetio API: NeMo Data Store
The NeMo Data Store microservice serves as the default file storage solution for the NeMo microservices platform. It exposes APIs compatible with the Hugging Face Hub client (`HfApi`), so you can use the client to interact with Data Store. The `NVIDIA_DATASETS_URL` environment variable should point to your NeMo Data Store endpoint.

See the [NVIDIA Datasetio docs](/llama_stack/providers/remote/datasetio/nvidia/README.md) for supported features and example usage.

### Eval API: NeMo Evaluator
The NeMo Evaluator microservice supports evaluation of LLMs. Launching an Evaluation job with NeMo Evaluator requires an Evaluation Config (an object that contains metadata needed by the job). A Llama Stack Benchmark maps to an Evaluation Config, so registering a Benchmark creates an Evaluation Config in NeMo Evaluator. The `NVIDIA_EVALUATOR_URL` environment variable should point to your NeMo Microservices endpoint.

See the [NVIDIA Eval docs](/llama_stack/providers/remote/eval/nvidia/README.md) for supported features and example usage.

### Post-Training API: NeMo Customizer
The NeMo Customizer microservice supports fine-tuning models. You can reference [this list of supported models](/llama_stack/providers/remote/post_training/nvidia/models.py) that can be fine-tuned using Llama Stack. The `NVIDIA_CUSTOMIZER_URL` environment variable should point to your NeMo Microservices endpoint.

See the [NVIDIA Post-Training docs](/llama_stack/providers/remote/post_training/nvidia/README.md) for supported features and example usage.

### Safety API: NeMo Guardrails
The NeMo Guardrails microservice sits between your application and the LLM, and adds checks and content moderation to a model. The `GUARDRAILS_SERVICE_URL` environment variable should point to your NeMo Microservices endpoint.

See the NVIDIA Safety docs for supported features and example usage.

## Deploying models
In order to use a registered model with the Llama Stack APIs, ensure the corresponding NIM is deployed to your environment. For example, you can use the NIM Proxy microservice to deploy `meta/llama-3.2-1b-instruct`.

Note: For improved inference speeds, we need to use NIM with the `fast_outlines` guided decoding system (specified in the request body). This is the default if you deployed the platform with the NeMo Microservices Helm Chart.
```sh
# URL to NeMo NIM Proxy service
export NEMO_URL="http://nemo.test"

curl --location "$NEMO_URL/v1/deployment/model-deployments" \
  -H 'accept: application/json' \
  -H 'Content-Type: application/json' \
  -d '{
    "name": "llama-3.2-1b-instruct",
    "namespace": "meta",
    "config": {
      "model": "meta/llama-3.2-1b-instruct",
      "nim_deployment": {
        "image_name": "nvcr.io/nim/meta/llama-3.2-1b-instruct",
        "image_tag": "1.8.3",
        "pvc_size": "25Gi",
        "gpu": 1,
        "additional_envs": {
          "NIM_GUIDED_DECODING_BACKEND": "fast_outlines"
        }
      }
    }
  }'
```
This NIM deployment should take approximately 10 minutes to go live. [See the docs](https://docs.nvidia.com/nemo/microservices/latest/get-started/tutorials/deploy-nims.html) for more information on how to deploy a NIM and verify it's available for inference.

You can also remove a deployed NIM to free up GPU resources, if needed.
```sh
export NEMO_URL="http://nemo.test"

curl -X DELETE "$NEMO_URL/v1/deployment/model-deployments/meta/llama-3.1-8b-instruct"
```

## Running Llama Stack with NVIDIA

-You can do this via Conda (build code) or Docker which has a pre-built image.
+You can do this via Conda or venv (build code), or Docker which has a pre-built image.

### Via Docker

@@ -78,9 +150,23 @@ docker run \
### Via Conda

```bash
INFERENCE_MODEL=meta-llama/Llama-3.1-8b-Instruct
llama stack build --template nvidia --image-type conda
llama stack run ./run.yaml \
  --port 8321 \
-  --env NVIDIA_API_KEY=$NVIDIA_API_KEY
+  --env NVIDIA_API_KEY=$NVIDIA_API_KEY \
  --env INFERENCE_MODEL=$INFERENCE_MODEL
```

### Via venv

If you've set up your local development environment, you can also build the image using your local virtual environment.

```bash
INFERENCE_MODEL=meta-llama/Llama-3.1-8b-Instruct
llama stack build --template nvidia --image-type venv
llama stack run ./run.yaml \
  --port 8321 \
  --env NVIDIA_API_KEY=$NVIDIA_API_KEY \
  --env INFERENCE_MODEL=$INFERENCE_MODEL
```
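Once the NIM is live and the stack is running, a quick way to sanity-check the setup is a plain chat completion against the distribution. This is a hedged sketch, assuming the server started with the commands above is listening on port 8321 and that `meta/llama-3.2-1b-instruct` is the model deployed through the NIM Proxy:

```python
from llama_stack_client import LlamaStackClient

client = LlamaStackClient(base_url="http://localhost:8321")

# Model id assumed to match the NIM deployed in the curl example above
response = client.inference.chat_completion(
    model_id="meta/llama-3.2-1b-instruct",
    messages=[{"role": "user", "content": "Say hello in one sentence."}],
)
print(response.completion_message.content)
```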
(Remote vLLM distribution documentation)

@@ -44,7 +44,7 @@ The following environment variables can be configured:
In the following sections, we'll use AMD, NVIDIA or Intel GPUs to serve as hardware accelerators for the vLLM
server, which acts as both the LLM inference provider and the safety provider. Note that vLLM also
[supports many other hardware accelerators](https://docs.vllm.ai/en/latest/getting_started/installation.html) and
-that we only use GPUs here for demonstration purposes.
+that we only use GPUs here for demonstration purposes. If you run into issues, you can pass `--env VLLM_DEBUG_LOG_API_SERVER_RESPONSE=true` (available in vLLM v0.8.3 and above) to the `docker run` command to enable logging of API server responses for debugging.

### Setting up vLLM server on AMD GPU
(External providers documentation)

@@ -50,9 +50,10 @@ Llama Stack supports two types of external providers:

Here's a list of known external providers that you can use with Llama Stack:

-| Type | Name | Description | Repository |
-|------|------|-------------|------------|
-| Remote | KubeFlow Training | Train models with KubeFlow | [llama-stack-provider-kft](https://github.com/opendatahub-io/llama-stack-provider-kft) |
+| Name | Description | API | Type | Repository |
+|------|-------------|-----|------|------------|
+| KubeFlow Training | Train models with KubeFlow | Post Training | Remote | [llama-stack-provider-kft](https://github.com/opendatahub-io/llama-stack-provider-kft) |
+| RamaLama | Inference models with RamaLama | Inference | Remote | [ramalama-stack](https://github.com/containers/ramalama-stack) |

### Remote Provider Specification
@ -225,8 +225,18 @@ class AgentConfigCommon(BaseModel):
|
|||
|
||||
@json_schema_type
|
||||
class AgentConfig(AgentConfigCommon):
|
||||
"""Configuration for an agent.
|
||||
|
||||
:param model: The model identifier to use for the agent
|
||||
:param instructions: The system instructions for the agent
|
||||
:param name: Optional name for the agent, used in telemetry and identification
|
||||
:param enable_session_persistence: Optional flag indicating whether session data has to be persisted
|
||||
:param response_format: Optional response format configuration
|
||||
"""
|
||||
|
||||
model: str
|
||||
instructions: str
|
||||
name: Optional[str] = None
|
||||
enable_session_persistence: Optional[bool] = False
|
||||
response_format: Optional[ResponseFormat] = None
|
||||
|
||||
|
|
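For illustration, a minimal sketch of how the optional `name` documented above might be supplied when creating an agent from the client side. The keyword arguments mirror the AgentConfig fields, but the exact client constructor signature is an assumption, not something this diff changes:

```python
from llama_stack_client import LlamaStackClient
from llama_stack_client.lib.agents.agent import Agent

client = LlamaStackClient(base_url="http://localhost:8321")

agent = Agent(
    client,
    model="meta-llama/Llama-3.2-3B-Instruct",     # required, as in AgentConfig.model
    instructions="You are a helpful assistant.",  # required, as in AgentConfig.instructions
    name="billing-support-agent",                 # new optional field, used in telemetry
    enable_session_persistence=False,
)
```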
(OpenAI-compatible inference types)

@@ -526,9 +526,9 @@ class OpenAIAssistantMessageParam(BaseModel):
    """

    role: Literal["assistant"] = "assistant"
-    content: OpenAIChatCompletionMessageContent
+    content: Optional[OpenAIChatCompletionMessageContent] = None
    name: Optional[str] = None
-    tool_calls: Optional[List[OpenAIChatCompletionToolCall]] = Field(default_factory=list)
+    tool_calls: Optional[List[OpenAIChatCompletionToolCall]] = None


@json_schema_type
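The motivation for relaxing `content` (and for dropping it from `required` in the spec changes above) is that an assistant turn may carry only tool calls. A hedged sketch of such a message, written as a plain dict rather than the pydantic model, with a made-up tool name:

```python
# An OpenAI-compatible assistant message that carries tool calls but no text content.
# With the change above, "content" may be omitted instead of being required.
assistant_message = {
    "role": "assistant",
    "tool_calls": [
        {
            "id": "call_0",
            "type": "function",
            "function": {
                "name": "get_weather",           # hypothetical tool name
                "arguments": '{"city": "Paris"}',
            },
        }
    ],
    # no "content" key: previously this would have failed schema validation
}
```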
(CLI: stack build command)

@@ -210,16 +210,9 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
        )
        sys.exit(1)

-    if build_config.image_type == LlamaStackImageType.CONTAINER.value and not args.image_name:
-        cprint(
-            "Please specify --image-name when building a container from a config file",
-            color="red",
-        )
-        sys.exit(1)
-
    if args.print_deps_only:
        print(f"# Dependencies for {args.template or args.config or image_name}")
-        normal_deps, special_deps = get_provider_dependencies(build_config.distribution_spec.providers)
+        normal_deps, special_deps = get_provider_dependencies(build_config)
        normal_deps += SERVER_DEPENDENCIES
        print(f"uv pip install {' '.join(normal_deps)}")
        for special_dep in special_deps:

@@ -235,10 +228,14 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
        )

    except (Exception, RuntimeError) as exc:
+        import traceback
+
        cprint(
            f"Error building stack: {exc}",
            color="red",
        )
+        cprint("Stack trace:", color="red")
+        traceback.print_exc()
        sys.exit(1)
    if run_config is None:
        cprint(

@@ -270,9 +267,10 @@ def _generate_run_config(
        image_name=image_name,
        apis=apis,
        providers={},
+        external_providers_dir=build_config.external_providers_dir if build_config.external_providers_dir else None,
    )
    # build providers dict
-    provider_registry = get_provider_registry()
+    provider_registry = get_provider_registry(build_config)
    for api in apis:
        run_config.providers[api] = []
        provider_types = build_config.distribution_spec.providers[api]

@@ -286,8 +284,22 @@ def _generate_run_config(
            if p.deprecation_error:
                raise InvalidProviderError(p.deprecation_error)

-            config_type = instantiate_class_type(provider_registry[Api(api)][provider_type].config_class)
-            if hasattr(config_type, "sample_run_config"):
+            try:
+                config_type = instantiate_class_type(provider_registry[Api(api)][provider_type].config_class)
+            except ModuleNotFoundError:
+                # HACK ALERT:
+                # This code executes after building is done, the import cannot work since the
+                # package is either available in the venv or container - not available on the host.
+                # TODO: use a "is_external" flag in ProviderSpec to check if the provider is
+                # external
+                cprint(
+                    f"Failed to import provider {provider_type} for API {api} - assuming it's external, skipping",
+                    color="yellow",
+                )
+                # Set config_type to None to avoid UnboundLocalError
+                config_type = None
+
+            if config_type is not None and hasattr(config_type, "sample_run_config"):
                config = config_type.sample_run_config(__distro_dir__=f"~/.llama/distributions/{image_name}")
            else:
                config = {}

@@ -305,11 +317,15 @@ def _generate_run_config(
        to_write = json.loads(run_config.model_dump_json())
        f.write(yaml.dump(to_write, sort_keys=False))

-    # this path is only invoked when no template is provided
-    cprint(
-        f"You can now run your stack with `llama stack run {run_config_file}`",
-        color="green",
-    )
+    # Only print this message for non-container builds since it will be displayed before the
+    # container is built
+    # For non-container builds, the run.yaml is generated at the very end of the build process so it
+    # makes sense to display this message
+    if build_config.image_type != LlamaStackImageType.CONTAINER.value:
+        cprint(
+            f"You can now run your stack with `llama stack run {run_config_file}`",
+            color="green",
+        )
    return run_config_file


@@ -319,6 +335,7 @@ def _run_stack_build_command_from_build_config(
    template_name: Optional[str] = None,
    config_path: Optional[str] = None,
) -> str:
+    image_name = image_name or build_config.image_name
    if build_config.image_type == LlamaStackImageType.CONTAINER.value:
        if template_name:
            image_name = f"distribution-{template_name}"

@@ -342,6 +359,13 @@ def _run_stack_build_command_from_build_config(
    build_file_path = build_dir / f"{image_name}-build.yaml"

    os.makedirs(build_dir, exist_ok=True)
+    run_config_file = None
+    # Generate the run.yaml so it can be included in the container image with the proper entrypoint
+    # Only do this if we're building a container image and we're not using a template
+    if build_config.image_type == LlamaStackImageType.CONTAINER.value and not template_name and config_path:
+        cprint("Generating run.yaml file", color="green")
+        run_config_file = _generate_run_config(build_config, build_dir, image_name)
+
    with open(build_file_path, "w") as f:
        to_write = json.loads(build_config.model_dump_json())
        f.write(yaml.dump(to_write, sort_keys=False))

@@ -350,7 +374,8 @@ def _run_stack_build_command_from_build_config(
        build_config,
        build_file_path,
        image_name,
-        template_or_config=template_name or config_path,
+        template_or_config=template_name or config_path or str(build_file_path),
+        run_config=run_config_file,
    )
    if return_code != 0:
        raise RuntimeError(f"Failed to build image {image_name}")
(Distribution build helpers)

@@ -7,16 +7,16 @@
import importlib.resources
import logging
from pathlib import Path
from typing import Dict, List

from pydantic import BaseModel
from termcolor import cprint

-from llama_stack.distribution.datatypes import BuildConfig, Provider
+from llama_stack.distribution.datatypes import BuildConfig
from llama_stack.distribution.distribution import get_provider_registry
from llama_stack.distribution.utils.exec import run_command
from llama_stack.distribution.utils.image_types import LlamaStackImageType
from llama_stack.providers.datatypes import Api
+from llama_stack.templates.template import DistributionTemplate

log = logging.getLogger(__name__)

@@ -37,19 +37,24 @@ class ApiInput(BaseModel):


def get_provider_dependencies(
-    config_providers: Dict[str, List[Provider]],
+    config: BuildConfig | DistributionTemplate,
) -> tuple[list[str], list[str]]:
    """Get normal and special dependencies from provider configuration."""
-    all_providers = get_provider_registry()
+    # Extract providers based on config type
+    if isinstance(config, DistributionTemplate):
+        providers = config.providers
+    elif isinstance(config, BuildConfig):
+        providers = config.distribution_spec.providers
    deps = []
+    registry = get_provider_registry(config)

-    for api_str, provider_or_providers in config_providers.items():
-        providers_for_api = all_providers[Api(api_str)]
+    for api_str, provider_or_providers in providers.items():
+        providers_for_api = registry[Api(api_str)]

        providers = provider_or_providers if isinstance(provider_or_providers, list) else [provider_or_providers]

        for provider in providers:
            # Providers from BuildConfig and RunConfig are subtly different – not great
            provider_type = provider if isinstance(provider, str) else provider.provider_type

            if provider_type not in providers_for_api:

@@ -71,8 +76,8 @@ def get_provider_dependencies(
    return list(set(normal_deps)), list(set(special_deps))


-def print_pip_install_help(providers: Dict[str, List[Provider]]):
-    normal_deps, special_deps = get_provider_dependencies(providers)
+def print_pip_install_help(config: BuildConfig):
+    normal_deps, special_deps = get_provider_dependencies(config)

    cprint(
        f"Please install needed dependencies using the following commands:\n\nuv pip install {' '.join(normal_deps)}",

@@ -88,10 +93,11 @@ def build_image(
    build_file_path: Path,
    image_name: str,
    template_or_config: str,
+    run_config: str | None = None,
):
    container_base = build_config.distribution_spec.container_image or "python:3.10-slim"

-    normal_deps, special_deps = get_provider_dependencies(build_config.distribution_spec.providers)
+    normal_deps, special_deps = get_provider_dependencies(build_config)
    normal_deps += SERVER_DEPENDENCIES

    if build_config.image_type == LlamaStackImageType.CONTAINER.value:

@@ -103,6 +109,11 @@ def build_image(
            container_base,
            " ".join(normal_deps),
        ]

+        # When building from a config file (not a template), include the run config path in the
+        # build arguments
+        if run_config is not None:
+            args.append(run_config)
    elif build_config.image_type == LlamaStackImageType.CONDA.value:
        script = str(importlib.resources.files("llama_stack") / "distribution/build_conda_env.sh")
        args = [
(Container build script)

@@ -19,12 +19,16 @@ UV_HTTP_TIMEOUT=${UV_HTTP_TIMEOUT:-500}
# mounting is not supported by docker buildx, so we use COPY instead
USE_COPY_NOT_MOUNT=${USE_COPY_NOT_MOUNT:-}

+# Path to the run.yaml file in the container
+RUN_CONFIG_PATH=/app/run.yaml
+
+BUILD_CONTEXT_DIR=$(pwd)
+
if [ "$#" -lt 4 ]; then
-  # This only works for templates
-  echo "Usage: $0 <template_or_config> <image_name> <container_base> <pip_dependencies> [<special_pip_deps>]" >&2
+  echo "Usage: $0 <template_or_config> <image_name> <container_base> <pip_dependencies> [<run_config>] [<special_pip_deps>]" >&2
  exit 1
fi

set -euo pipefail

template_or_config="$1"

@@ -35,8 +39,27 @@ container_base="$1"
shift
pip_dependencies="$1"
shift
-special_pip_deps="${1:-}"

# Handle optional arguments
run_config=""
special_pip_deps=""

# Check if there are more arguments
# The logic is becoming cumbersome, we should refactor it if we can do better
if [ $# -gt 0 ]; then
  # Check if the argument ends with .yaml
  if [[ "$1" == *.yaml ]]; then
    run_config="$1"
    shift
    # If there's another argument after .yaml, it must be special_pip_deps
    if [ $# -gt 0 ]; then
      special_pip_deps="$1"
    fi
  else
    # If it's not .yaml, it must be special_pip_deps
    special_pip_deps="$1"
  fi
fi

# Define color codes
RED='\033[0;31m'

@@ -72,9 +95,13 @@ if [[ $container_base == *"registry.access.redhat.com/ubi9"* ]]; then
FROM $container_base
WORKDIR /app

-RUN dnf -y update && dnf install -y iputils net-tools wget \
+# We install the Python 3.11 dev headers and build tools so that any
+# C-extension wheels (e.g. polyleven, faiss-cpu) can compile successfully.
+RUN dnf -y update && dnf install -y iputils git net-tools wget \
  vim-minimal python3.11 python3.11-pip python3.11-wheel \
-  python3.11-setuptools && ln -s /bin/pip3.11 /bin/pip && ln -s /bin/python3.11 /bin/python && dnf clean all
+  python3.11-setuptools python3.11-devel gcc make && \
+  ln -s /bin/pip3.11 /bin/pip && ln -s /bin/python3.11 /bin/python && dnf clean all

ENV UV_SYSTEM_PYTHON=1
RUN pip install uv

@@ -86,7 +113,7 @@ WORKDIR /app

RUN apt-get update && apt-get install -y \
       iputils-ping net-tools iproute2 dnsutils telnet \
-       curl wget telnet \
+       curl wget telnet git \
       procps psmisc lsof \
       traceroute \
       bubblewrap \

@@ -115,6 +142,45 @@ EOF
  done
fi

# Function to get Python command
get_python_cmd() {
  if is_command_available python; then
    echo "python"
  elif is_command_available python3; then
    echo "python3"
  else
    echo "Error: Neither python nor python3 is installed. Please install Python to continue." >&2
    exit 1
  fi
}

if [ -n "$run_config" ]; then
  # Copy the run config to the build context since it's an absolute path
  cp "$run_config" "$BUILD_CONTEXT_DIR/run.yaml"
  add_to_container << EOF
COPY run.yaml $RUN_CONFIG_PATH
EOF

  # Parse the run.yaml configuration to identify external provider directories
  # If external providers are specified, copy their directory to the container
  # and update the configuration to reference the new container path
  python_cmd=$(get_python_cmd)
  external_providers_dir=$($python_cmd -c "import yaml; config = yaml.safe_load(open('$run_config')); print(config.get('external_providers_dir') or '')")
  if [ -n "$external_providers_dir" ]; then
    echo "Copying external providers directory: $external_providers_dir"
    add_to_container << EOF
COPY $external_providers_dir /app/providers.d
EOF
    # Edit the run.yaml file to change the external_providers_dir to /app/providers.d
    if [ "$(uname)" = "Darwin" ]; then
      sed -i.bak -e 's|external_providers_dir:.*|external_providers_dir: /app/providers.d|' "$BUILD_CONTEXT_DIR/run.yaml"
      rm -f "$BUILD_CONTEXT_DIR/run.yaml.bak"
    else
      sed -i 's|external_providers_dir:.*|external_providers_dir: /app/providers.d|' "$BUILD_CONTEXT_DIR/run.yaml"
    fi
  fi
fi

stack_mount="/app/llama-stack-source"
client_mount="/app/llama-stack-client-source"

@@ -174,15 +240,16 @@ fi
RUN pip uninstall -y uv
EOF

-# if template_or_config ends with .yaml, it is not a template and we should not use the --template flag
-if [[ "$template_or_config" != *.yaml ]]; then
+# If a run config is provided, we use the --config flag
+if [[ -n "$run_config" ]]; then
+  add_to_container << EOF
+ENTRYPOINT ["python", "-m", "llama_stack.distribution.server.server", "--config", "$RUN_CONFIG_PATH"]
+EOF
+# If a template is provided (not a yaml file), we use the --template flag
+elif [[ "$template_or_config" != *.yaml ]]; then
  add_to_container << EOF
ENTRYPOINT ["python", "-m", "llama_stack.distribution.server.server", "--template", "$template_or_config"]
EOF
else
  add_to_container << EOF
ENTRYPOINT ["python", "-m", "llama_stack.distribution.server.server"]
EOF
fi

# Add other required commands generic to all containers

@@ -254,9 +321,10 @@ $CONTAINER_BINARY build \
  "${CLI_ARGS[@]}" \
  -t "$image_tag" \
  -f "$TEMP_DIR/Containerfile" \
-  "."
+  "$BUILD_CONTEXT_DIR"

+# clean up tmp/configs
+rm -f "$BUILD_CONTEXT_DIR/run.yaml"
set +x

echo "Success!"
(Distribution datatypes: BuildConfig)

@@ -326,3 +326,12 @@ class BuildConfig(BaseModel):
        default="conda",
        description="Type of package to build (conda | container | venv)",
    )
+    image_name: Optional[str] = Field(
+        default=None,
+        description="Name of the distribution to build",
+    )
+    external_providers_dir: Optional[str] = Field(
+        default=None,
+        description="Path to directory containing external provider implementations. The providers packages will be resolved from this directory. "
+        "pip_packages MUST contain the provider package name.",
+    )
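As a rough illustration of how the new fields get populated, the example external-provider build file from the documentation section above can be loaded straight into this model. This is a hedged sketch: the YAML is the made-up `my-external-stack.yaml` example from the docs, and it assumes `BuildConfig` accepts exactly these keys:

```python
import yaml

from llama_stack.distribution.datatypes import BuildConfig

raw = yaml.safe_load("""
version: '2'
image_type: container
image_name: ci-test
distribution_spec:
  description: Custom distro for CI tests
  providers:
    inference:
      - remote::custom_ollama
external_providers_dir: /etc/llama-stack/providers.d
""")

build_config = BuildConfig(**raw)
print(build_config.image_name, build_config.external_providers_dir)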
(Distribution provider registry)

@@ -12,7 +12,6 @@ from typing import Any, Dict, List
import yaml
from pydantic import BaseModel

-from llama_stack.distribution.datatypes import StackRunConfig
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import (
    AdapterSpec,

@@ -97,7 +96,9 @@ def _load_inline_provider_spec(spec_data: Dict[str, Any], api: Api, provider_name
    return spec


-def get_provider_registry(config: StackRunConfig | None = None) -> Dict[Api, Dict[str, ProviderSpec]]:
+def get_provider_registry(
+    config=None,
+) -> Dict[Api, Dict[str, ProviderSpec]]:
    """Get the provider registry, optionally including external providers.

    This function loads both built-in providers and external providers from YAML files.

@@ -122,7 +123,7 @@ def get_provider_registry(
            llama-guard.yaml

    Args:
-        config: Optional StackRunConfig containing the external providers directory path
+        config: Optional object containing the external providers directory path

    Returns:
        A dictionary mapping APIs to their available providers

@@ -142,7 +143,8 @@ def get_provider_registry(
        except ImportError as e:
            logger.warning(f"Failed to import module {name}: {e}")

-    if config and config.external_providers_dir:
+    # Check if config has the external_providers_dir attribute
+    if config and hasattr(config, "external_providers_dir") and config.external_providers_dir:
        external_providers_dir = os.path.abspath(config.external_providers_dir)
        if not os.path.exists(external_providers_dir):
            raise FileNotFoundError(f"External providers directory not found: {external_providers_dir}")
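The switch from a typed `StackRunConfig` parameter to duck typing means any object exposing an `external_providers_dir` attribute can now drive external-provider discovery, which is what lets the build path pass a `BuildConfig`. A hedged sketch of the idea (the directory path is hypothetical and must exist, otherwise the `FileNotFoundError` above is raised):

```python
from llama_stack.distribution.distribution import get_provider_registry

class AnyConfig:
    """Any object carrying an external_providers_dir attribute now works."""
    external_providers_dir = "/etc/llama-stack/providers.d"  # hypothetical path

# Built-in providers only
registry = get_provider_registry()

# Built-in providers plus any external specs found under providers.d
registry_with_external = get_provider_registry(AnyConfig())
```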
(Inference router)

@@ -8,6 +8,11 @@ import asyncio
import time
from typing import Any, AsyncGenerator, AsyncIterator, Dict, List, Optional, Union

+from openai.types.chat import ChatCompletionToolChoiceOptionParam as OpenAIChatCompletionToolChoiceOptionParam
+from openai.types.chat import ChatCompletionToolParam as OpenAIChatCompletionToolParam
+from pydantic import Field, TypeAdapter
+from typing_extensions import Annotated
+
from llama_stack.apis.common.content_types import (
    URL,
    InterleavedContent,

@@ -526,7 +531,7 @@ class InferenceRouter(Inference):
    async def openai_chat_completion(
        self,
        model: str,
-        messages: List[OpenAIMessageParam],
+        messages: Annotated[List[OpenAIMessageParam], Field(..., min_length=1)],
        frequency_penalty: Optional[float] = None,
        function_call: Optional[Union[str, Dict[str, Any]]] = None,
        functions: Optional[List[Dict[str, Any]]] = None,

@@ -558,6 +563,16 @@ class InferenceRouter(Inference):
        if model_obj.model_type == ModelType.embedding:
            raise ValueError(f"Model '{model}' is an embedding model and does not support chat completions")

+        # Use the OpenAI client for a bit of extra input validation without
+        # exposing the OpenAI client itself as part of our API surface
+        if tool_choice:
+            TypeAdapter(OpenAIChatCompletionToolChoiceOptionParam).validate_python(tool_choice)
+            if tools is None:
+                raise ValueError("'tool_choice' is only allowed when 'tools' is also provided")
+        if tools:
+            for tool in tools:
+                TypeAdapter(OpenAIChatCompletionToolParam).validate_python(tool)
+
        params = dict(
            model=model_obj.identifier,
            messages=messages,
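In practice the new checks turn a malformed request into an early `ValueError` (which the server change below maps to an HTTP 400) instead of a provider-side failure. A hedged sketch of the behaviour this enforces, assuming `router` is an `InferenceRouter` already wired up by a running stack and that the remaining keyword arguments match the full signature:

```python
import asyncio

async def demo(router):
    # tool_choice supplied without tools is now rejected before the provider is called
    try:
        await router.openai_chat_completion(
            model="meta-llama/Llama-3.2-3B-Instruct",
            messages=[{"role": "user", "content": "What is the weather?"}],
            tool_choice="auto",  # provided...
            tools=None,          # ...but no tools: raises ValueError
        )
    except ValueError as exc:
        print(f"rejected: {exc}")
```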
(Server request handling)

@@ -22,6 +22,7 @@ from fastapi import Body, FastAPI, HTTPException, Request
from fastapi import Path as FastapiPath
from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse, StreamingResponse
+from openai import BadRequestError
from pydantic import BaseModel, ValidationError
from typing_extensions import Annotated


@@ -92,7 +93,7 @@ async def global_exception_handler(request: Request, exc: Exception):

def translate_exception(exc: Exception) -> Union[HTTPException, RequestValidationError]:
    if isinstance(exc, ValidationError):
-        exc = RequestValidationError(exc.raw_errors)
+        exc = RequestValidationError(exc.errors())

    if isinstance(exc, RequestValidationError):
        return HTTPException(

@@ -110,6 +111,8 @@ def translate_exception(exc: Exception) -> Union[HTTPException, RequestValidatio
        )
    elif isinstance(exc, ValueError):
        return HTTPException(status_code=400, detail=f"Invalid value: {str(exc)}")
+    elif isinstance(exc, BadRequestError):
+        return HTTPException(status_code=400, detail=str(exc))
    elif isinstance(exc, PermissionError):
        return HTTPException(status_code=403, detail=f"Permission denied: {str(exc)}")
    elif isinstance(exc, TimeoutError):

@@ -162,9 +165,10 @@ async def maybe_await(value):
    return value


-async def sse_generator(event_gen):
+async def sse_generator(event_gen_coroutine):
+    event_gen = await event_gen_coroutine
    try:
-        async for item in await event_gen:
+        async for item in event_gen:
            yield create_sse_event(item)
            await asyncio.sleep(0.01)
    except asyncio.CancelledError:
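The `sse_generator` change is subtle: the endpoint handler is a coroutine that returns an async generator, so it has to be awaited once up front rather than inside the `async for`. A small self-contained sketch of the pattern, independent of the server code (the names here are illustrative only):

```python
import asyncio

async def make_stream():
    # Stands in for the awaitable endpoint call that returns an async generator
    async def gen():
        for i in range(3):
            yield f"event-{i}"
    return gen()

async def sse_generator(event_gen_coroutine):
    event_gen = await event_gen_coroutine  # resolve the coroutine once
    async for item in event_gen:           # then iterate the generator it returned
        print(item)

asyncio.run(sse_generator(make_stream()))
```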
(Playground UI: RAG page)

@@ -24,6 +24,13 @@ def rag_chat_page():
    def should_disable_input():
        return "displayed_messages" in st.session_state and len(st.session_state.displayed_messages) > 0

+    def log_message(message):
+        with st.chat_message(message["role"]):
+            if "tool_output" in message and message["tool_output"]:
+                with st.expander(label="Tool Output", expanded=False, icon="🛠"):
+                    st.write(message["tool_output"])
+            st.markdown(message["content"])
+
    with st.sidebar:
        # File/Directory Upload Section
        st.subheader("Upload Documents", divider=True)

@@ -146,8 +153,7 @@ def rag_chat_page():

    # Display chat history
    for message in st.session_state.displayed_messages:
-        with st.chat_message(message["role"]):
-            st.markdown(message["content"])
+        log_message(message)

    if temperature > 0.0:
        strategy = {

@@ -201,7 +207,7 @@ def rag_chat_page():

        # Display assistant response
        with st.chat_message("assistant"):
-            retrieval_message_placeholder = st.empty()
+            retrieval_message_placeholder = st.expander(label="Tool Output", expanded=False, icon="🛠")
            message_placeholder = st.empty()
            full_response = ""
            retrieval_response = ""

@@ -209,14 +215,16 @@ def rag_chat_page():
                log.print()
                if log.role == "tool_execution":
                    retrieval_response += log.content.replace("====", "").strip()
-                    retrieval_message_placeholder.info(retrieval_response)
+                    retrieval_message_placeholder.write(retrieval_response)
                else:
                    full_response += log.content
                    message_placeholder.markdown(full_response + "▌")
            message_placeholder.markdown(full_response)

            st.session_state.messages.append({"role": "assistant", "content": full_response})
-            st.session_state.displayed_messages.append({"role": "assistant", "content": full_response})
+            st.session_state.displayed_messages.append(
+                {"role": "assistant", "content": full_response, "tool_output": retrieval_response}
+            )

    def direct_process_prompt(prompt):
        # Add the system prompt in the beginning of the conversation

@@ -230,15 +238,14 @@ def rag_chat_page():
        prompt_context = rag_response.content

        with st.chat_message("assistant"):
-            with st.expander(label="Retrieval Output", expanded=False):
-                st.write(prompt_context)
-
            retrieval_message_placeholder = st.empty()
            message_placeholder = st.empty()
            full_response = ""
            retrieval_response = ""

+            # Display the retrieved content
+            retrieval_response += str(prompt_context)
+            retrieval_message_placeholder.info(retrieval_response)
+
            # Construct the extended prompt
            extended_prompt = f"Please answer the following query using the context below.\n\nCONTEXT:\n{prompt_context}\n\nQUERY:\n{prompt}"
|
@ -29,17 +29,39 @@ def tool_chat_page():
|
|||
st.cache_resource.clear()
|
||||
|
||||
with st.sidebar:
|
||||
st.title("Configuration")
|
||||
st.subheader("Model")
|
||||
model = st.selectbox(label="models", options=model_list, on_change=reset_agent)
|
||||
model = st.selectbox(label="Model", options=model_list, on_change=reset_agent, label_visibility="collapsed")
|
||||
|
||||
st.subheader("Available ToolGroups")
|
||||
|
||||
st.subheader("Builtin Tools")
|
||||
toolgroup_selection = st.pills(
|
||||
label="Available ToolGroups", options=builtin_tools_list, selection_mode="multi", on_change=reset_agent
|
||||
label="Built-in tools",
|
||||
options=builtin_tools_list,
|
||||
selection_mode="multi",
|
||||
on_change=reset_agent,
|
||||
format_func=lambda tool: "".join(tool.split("::")[1:]),
|
||||
help="List of built-in tools from your llama stack server.",
|
||||
)
|
||||
|
||||
st.subheader("MCP Servers")
|
||||
if "builtin::rag" in toolgroup_selection:
|
||||
vector_dbs = llama_stack_api.client.vector_dbs.list() or []
|
||||
if not vector_dbs:
|
||||
st.info("No vector databases available for selection.")
|
||||
vector_dbs = [vector_db.identifier for vector_db in vector_dbs]
|
||||
selected_vector_dbs = st.multiselect(
|
||||
label="Select Document Collections to use in RAG queries",
|
||||
options=vector_dbs,
|
||||
on_change=reset_agent,
|
||||
)
|
||||
|
||||
mcp_selection = st.pills(
|
||||
label="Available MCP Servers", options=mcp_tools_list, selection_mode="multi", on_change=reset_agent
|
||||
label="MCP Servers",
|
||||
options=mcp_tools_list,
|
||||
selection_mode="multi",
|
||||
on_change=reset_agent,
|
||||
format_func=lambda tool: "".join(tool.split("::")[1:]),
|
||||
help="List of MCP servers registered to your llama stack server.",
|
||||
)
|
||||
|
||||
toolgroup_selection.extend(mcp_selection)
|
||||
|
@ -53,10 +75,10 @@ def tool_chat_page():
|
|||
]
|
||||
)
|
||||
|
||||
st.subheader(f"Active Tools: 🛠 {len(active_tool_list)}")
|
||||
st.markdown(f"Active Tools: 🛠 {len(active_tool_list)}", help="List of currently active tools.")
|
||||
st.json(active_tool_list)
|
||||
|
||||
st.subheader("Chat Configurations")
|
||||
st.subheader("Agent Configurations")
|
||||
max_tokens = st.slider(
|
||||
"Max Tokens",
|
||||
min_value=0,
|
||||
|
@ -67,6 +89,16 @@ def tool_chat_page():
|
|||
on_change=reset_agent,
|
||||
)
|
||||
|
||||
for i, tool_name in enumerate(toolgroup_selection):
|
||||
if tool_name == "builtin::rag":
|
||||
tool_dict = dict(
|
||||
name="builtin::rag",
|
||||
args={
|
||||
"vector_db_ids": list(selected_vector_dbs),
|
||||
},
|
||||
)
|
||||
toolgroup_selection[i] = tool_dict
|
||||
|
||||
@st.cache_resource
|
||||
def create_agent():
|
||||
return Agent(
|
||||
|
@ -112,7 +144,11 @@ def tool_chat_page():
|
|||
yield response.event.payload.delta.text
|
||||
if response.event.payload.event_type == "step_complete":
|
||||
if response.event.payload.step_details.step_type == "tool_execution":
|
||||
yield " 🛠 "
|
||||
if response.event.payload.step_details.tool_calls:
|
||||
tool_name = str(response.event.payload.step_details.tool_calls[0].tool_name)
|
||||
yield f'\n\n🛠 :grey[_Using "{tool_name}" tool:_]\n\n'
|
||||
else:
|
||||
yield "No tool_calls present in step_details"
|
||||
else:
|
||||
yield f"Error occurred in the Llama Stack Cluster: {response}"
|
||||
|
||||
|
|
|
@ -5,6 +5,7 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
import io
|
||||
import json
|
||||
import uuid
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
@ -299,6 +300,7 @@ class ChatFormat:
|
|||
call_id=call_id,
|
||||
tool_name=tool_name,
|
||||
arguments=tool_arguments,
|
||||
arguments_json=json.dumps(tool_arguments),
|
||||
)
|
||||
)
|
||||
|
||||
|
|
|
@ -178,6 +178,8 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
span.set_attribute("request", request.model_dump_json())
|
||||
turn_id = str(uuid.uuid4())
|
||||
span.set_attribute("turn_id", turn_id)
|
||||
if self.agent_config.name:
|
||||
span.set_attribute("agent_name", self.agent_config.name)
|
||||
|
||||
await self._initialize_tools(request.toolgroups)
|
||||
async for chunk in self._run_turn(request, turn_id):
|
||||
|
@ -190,6 +192,8 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
span.set_attribute("session_id", request.session_id)
|
||||
span.set_attribute("request", request.model_dump_json())
|
||||
span.set_attribute("turn_id", request.turn_id)
|
||||
if self.agent_config.name:
|
||||
span.set_attribute("agent_name", self.agent_config.name)
|
||||
|
||||
await self._initialize_tools()
|
||||
async for chunk in self._run_turn(request):
|
||||
|
@ -498,6 +502,8 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
stop_reason = None
|
||||
|
||||
async with tracing.span("inference") as span:
|
||||
if self.agent_config.name:
|
||||
span.set_attribute("agent_name", self.agent_config.name)
|
||||
async for chunk in await self.inference_api.chat_completion(
|
||||
self.agent_config.model,
|
||||
input_messages,
|
||||
|
|
|
@ -515,7 +515,8 @@ class MetaReferenceInferenceImpl(
|
|||
stop_reason = None
|
||||
ipython = False
|
||||
|
||||
for token_result in self.generator.chat_completion(request):
|
||||
for token_results in self.generator.chat_completion([request]):
|
||||
token_result = token_results[0]
|
||||
if os.environ.get("LLAMA_MODELS_DEBUG", "0") == "1":
|
||||
cprint(token_result.text, "cyan", end="")
|
||||
if os.environ.get("LLAMA_MODELS_DEBUG", "0") == "2":
|
||||
|
|
|
@ -33,6 +33,7 @@ from llama_stack.apis.tools import (
|
|||
)
|
||||
from llama_stack.apis.vector_io import QueryChunksResponse, VectorIO
|
||||
from llama_stack.providers.datatypes import ToolsProtocolPrivate
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import interleaved_content_as_str
|
||||
from llama_stack.providers.utils.memory.vector_store import (
|
||||
content_from_doc,
|
||||
make_overlapped_chunks,
|
||||
|
@ -153,6 +154,11 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime):
|
|||
)
|
||||
)
|
||||
picked.append(TextContentItem(text="END of knowledge_search tool results.\n"))
|
||||
picked.append(
|
||||
TextContentItem(
|
||||
text=f'The above results were retrieved to help answer the user\'s query: "{interleaved_content_as_str(content)}". Use them as supporting information only in answering this query.\n',
|
||||
)
|
||||
)
|
||||
|
||||
return RAGQueryResult(
|
||||
content=picked,
|
||||
|
|
|
@ -362,6 +362,39 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
|
|||
user: Optional[str] = None,
|
||||
) -> Union[OpenAIChatCompletion, AsyncIterator[OpenAIChatCompletionChunk]]:
|
||||
model_obj = await self.model_store.get_model(model)
|
||||
|
||||
# Divert Llama Models through Llama Stack inference APIs because
|
||||
# Fireworks chat completions OpenAI-compatible API does not support
|
||||
# tool calls properly.
|
||||
llama_model = self.get_llama_model(model_obj.provider_resource_id)
|
||||
if llama_model:
|
||||
return await OpenAIChatCompletionToLlamaStackMixin.openai_chat_completion(
|
||||
self,
|
||||
model=model,
|
||||
messages=messages,
|
||||
frequency_penalty=frequency_penalty,
|
||||
function_call=function_call,
|
||||
functions=functions,
|
||||
logit_bias=logit_bias,
|
||||
logprobs=logprobs,
|
||||
max_completion_tokens=max_completion_tokens,
|
||||
max_tokens=max_tokens,
|
||||
n=n,
|
||||
parallel_tool_calls=parallel_tool_calls,
|
||||
presence_penalty=presence_penalty,
|
||||
response_format=response_format,
|
||||
seed=seed,
|
||||
stop=stop,
|
||||
stream=stream,
|
||||
stream_options=stream_options,
|
||||
temperature=temperature,
|
||||
tool_choice=tool_choice,
|
||||
tools=tools,
|
||||
top_logprobs=top_logprobs,
|
||||
top_p=top_p,
|
||||
user=user,
|
||||
)
|
||||
|
||||
params = await prepare_openai_completion_params(
|
||||
messages=messages,
|
||||
frequency_penalty=frequency_penalty,
|
||||
|
@ -387,11 +420,4 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
|
|||
user=user,
|
||||
)
|
||||
|
||||
# Divert Llama Models through Llama Stack inference APIs because
|
||||
# Fireworks chat completions OpenAI-compatible API does not support
|
||||
# tool calls properly.
|
||||
llama_model = self.get_llama_model(model_obj.provider_resource_id)
|
||||
if llama_model:
|
||||
return await OpenAIChatCompletionToLlamaStackMixin.openai_chat_completion(self, model=model, **params)
|
||||
|
||||
return await self._get_openai_client().chat.completions.create(model=model_obj.provider_resource_id, **params)
|
||||
|
|
85
llama_stack/providers/remote/inference/nvidia/NVIDIA.md
Normal file
|
@ -0,0 +1,85 @@
|
|||
# NVIDIA Inference Provider for LlamaStack
|
||||
|
||||
This provider enables running inference using NVIDIA NIM.
|
||||
|
||||
## Features
|
||||
- Endpoints for completions, chat completions, and embeddings for registered models
|
||||
|
||||
## Getting Started
|
||||
|
||||
### Prerequisites
|
||||
|
||||
- LlamaStack with NVIDIA configuration
- Access to an NVIDIA NIM deployment
- A NIM deployed for the model you want to use for inference
|
||||
|
||||
### Setup
|
||||
|
||||
Build the NVIDIA environment:
|
||||
|
||||
```bash
|
||||
llama stack build --template nvidia --image-type conda
|
||||
```
|
||||
|
||||
### Basic Usage using the LlamaStack Python Client
|
||||
|
||||
#### Initialize the client
|
||||
|
||||
```python
|
||||
import os
|
||||
|
||||
os.environ["NVIDIA_API_KEY"] = (
|
||||
"" # Required if using hosted NIM endpoint. If self-hosted, not required.
|
||||
)
|
||||
os.environ["NVIDIA_BASE_URL"] = "http://nim.test" # NIM URL
|
||||
|
||||
from llama_stack.distribution.library_client import LlamaStackAsLibraryClient
|
||||
|
||||
client = LlamaStackAsLibraryClient("nvidia")
|
||||
client.initialize()
|
||||
```
|
||||
|
||||
### Create Completion
|
||||
|
||||
```python
|
||||
response = client.completion(
|
||||
model_id="meta-llama/Llama-3.1-8b-Instruct",
|
||||
content="Complete the sentence using one word: Roses are red, violets are :",
|
||||
stream=False,
|
||||
sampling_params={
|
||||
"max_tokens": 50,
|
||||
},
|
||||
)
|
||||
print(f"Response: {response.content}")
|
||||
```
|
||||
|
||||
### Create Chat Completion
|
||||
|
||||
```python
|
||||
response = client.chat_completion(
|
||||
model_id="meta-llama/Llama-3.1-8b-Instruct",
|
||||
messages=[
|
||||
{
|
||||
"role": "system",
|
||||
"content": "You must respond to each message with only one word",
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Complete the sentence using one word: Roses are red, violets are:",
|
||||
},
|
||||
],
|
||||
stream=False,
|
||||
sampling_params={
|
||||
"max_tokens": 50,
|
||||
},
|
||||
)
|
||||
print(f"Response: {response.completion_message.content}")
|
||||
```
|
||||
|
||||
### Create Embeddings
|
||||
```python
|
||||
response = client.embeddings(
|
||||
model_id="meta-llama/Llama-3.1-8b-Instruct", contents=["foo", "bar", "baz"]
|
||||
)
|
||||
print(f"Embeddings: {response.embeddings}")
|
||||
```
|
|
@ -48,6 +48,10 @@ MODEL_ENTRIES = [
|
|||
"meta/llama-3.2-90b-vision-instruct",
|
||||
CoreModelId.llama3_2_90b_vision_instruct.value,
|
||||
),
|
||||
build_hf_repo_model_entry(
|
||||
"meta/llama-3.3-70b-instruct",
|
||||
CoreModelId.llama3_3_70b_instruct.value,
|
||||
),
|
||||
# NeMo Retriever Text Embedding models -
|
||||
#
|
||||
# https://docs.nvidia.com/nim/nemo-retriever/text-embedding/latest/support-matrix.html
|
||||
|
|
|
@ -129,6 +129,14 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
|
|||
base_url = special_model_urls[provider_model_id]
|
||||
return _get_client_for_base_url(base_url)
|
||||
|
||||
async def _get_provider_model_id(self, model_id: str) -> str:
|
||||
if not self.model_store:
|
||||
raise RuntimeError("Model store is not set")
|
||||
model = await self.model_store.get_model(model_id)
|
||||
if model is None:
|
||||
raise ValueError(f"Model {model_id} is unknown")
|
||||
return model.provider_model_id
|
||||
|
||||
async def completion(
|
||||
self,
|
||||
model_id: str,
|
||||
|
@ -147,7 +155,7 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
|
|||
# removing this health check as NeMo customizer endpoint health check is returning 404
|
||||
# await check_health(self._config) # this raises errors
|
||||
|
||||
provider_model_id = self.get_provider_model_id(model_id)
|
||||
provider_model_id = await self._get_provider_model_id(model_id)
|
||||
request = convert_completion_request(
|
||||
request=CompletionRequest(
|
||||
model=provider_model_id,
|
||||
|
@ -191,7 +199,7 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
|
|||
#
|
||||
flat_contents = [content.text if isinstance(content, TextContentItem) else content for content in contents]
|
||||
input = [content.text if isinstance(content, TextContentItem) else content for content in flat_contents]
|
||||
model = self.get_provider_model_id(model_id)
|
||||
provider_model_id = await self._get_provider_model_id(model_id)
|
||||
|
||||
extra_body = {}
|
||||
|
||||
|
@ -214,8 +222,8 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
|
|||
extra_body["input_type"] = task_type_options[task_type]
|
||||
|
||||
try:
|
||||
response = await self._get_client(model).embeddings.create(
|
||||
model=model,
|
||||
response = await self._get_client(provider_model_id).embeddings.create(
|
||||
model=provider_model_id,
|
||||
input=input,
|
||||
extra_body=extra_body,
|
||||
)
|
||||
|
@ -249,10 +257,10 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
|
|||
|
||||
# await check_health(self._config) # this raises errors
|
||||
|
||||
provider_model_id = self.get_provider_model_id(model_id)
|
||||
provider_model_id = await self._get_provider_model_id(model_id)
|
||||
request = await convert_chat_completion_request(
|
||||
request=ChatCompletionRequest(
|
||||
model=self.get_provider_model_id(model_id),
|
||||
model=provider_model_id,
|
||||
messages=messages,
|
||||
sampling_params=sampling_params,
|
||||
response_format=response_format,
|
||||
|
@ -297,7 +305,7 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
|
|||
guided_choice: Optional[List[str]] = None,
|
||||
prompt_logprobs: Optional[int] = None,
|
||||
) -> OpenAICompletion:
|
||||
provider_model_id = self.get_provider_model_id(model)
|
||||
provider_model_id = await self._get_provider_model_id(model)
|
||||
|
||||
params = await prepare_openai_completion_params(
|
||||
model=provider_model_id,
|
||||
|
@ -350,7 +358,7 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
|
|||
top_p: Optional[float] = None,
|
||||
user: Optional[str] = None,
|
||||
) -> Union[OpenAIChatCompletion, AsyncIterator[OpenAIChatCompletionChunk]]:
|
||||
provider_model_id = self.get_provider_model_id(model)
|
||||
provider_model_id = await self._get_provider_model_id(model)
|
||||
|
||||
params = await prepare_openai_completion_params(
|
||||
model=provider_model_id,
|
||||
|
|
|
@ -76,8 +76,11 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
|
|||
|
||||
async def shutdown(self) -> None:
|
||||
if self._client:
|
||||
await self._client.close()
|
||||
# Together client has no close method, so just set to None
|
||||
self._client = None
|
||||
if self._openai_client:
|
||||
await self._openai_client.close()
|
||||
self._openai_client = None
|
||||
|
||||
async def completion(
|
||||
self,
|
||||
|
@ -359,7 +362,7 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
|
|||
top_p=top_p,
|
||||
user=user,
|
||||
)
|
||||
if params.get("stream", True):
|
||||
if params.get("stream", False):
|
||||
return self._stream_openai_chat_completion(params)
|
||||
return await self._get_openai_client().chat.completions.create(**params) # type: ignore
|
||||
|
||||
|
|
|
@ -231,12 +231,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
|
|||
self.client = None
|
||||
|
||||
async def initialize(self) -> None:
|
||||
log.info(f"Initializing VLLM client with base_url={self.config.url}")
|
||||
self.client = AsyncOpenAI(
|
||||
base_url=self.config.url,
|
||||
api_key=self.config.api_token,
|
||||
http_client=None if self.config.tls_verify else httpx.AsyncClient(verify=False),
|
||||
)
|
||||
pass
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
pass
|
||||
|
@ -249,6 +244,20 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
|
|||
raise ValueError("Model store not set")
|
||||
return await self.model_store.get_model(model_id)
|
||||
|
||||
def _lazy_initialize_client(self):
|
||||
if self.client is not None:
|
||||
return
|
||||
|
||||
log.info(f"Initializing vLLM client with base_url={self.config.url}")
|
||||
self.client = self._create_client()
|
||||
|
||||
def _create_client(self):
|
||||
return AsyncOpenAI(
|
||||
base_url=self.config.url,
|
||||
api_key=self.config.api_token,
|
||||
http_client=None if self.config.tls_verify else httpx.AsyncClient(verify=False),
|
||||
)
|
||||
|
||||
async def completion(
|
||||
self,
|
||||
model_id: str,
|
||||
|
@ -258,6 +267,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
|
|||
stream: Optional[bool] = False,
|
||||
logprobs: Optional[LogProbConfig] = None,
|
||||
) -> CompletionResponse | AsyncGenerator[CompletionResponseStreamChunk, None]:
|
||||
self._lazy_initialize_client()
|
||||
if sampling_params is None:
|
||||
sampling_params = SamplingParams()
|
||||
model = await self._get_model(model_id)
|
||||
|
@ -287,6 +297,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
|
|||
logprobs: Optional[LogProbConfig] = None,
|
||||
tool_config: Optional[ToolConfig] = None,
|
||||
) -> ChatCompletionResponse | AsyncGenerator[ChatCompletionResponseStreamChunk, None]:
|
||||
self._lazy_initialize_client()
|
||||
if sampling_params is None:
|
||||
sampling_params = SamplingParams()
|
||||
model = await self._get_model(model_id)
|
||||
|
@ -357,9 +368,12 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
|
|||
yield chunk
|
||||
|
||||
async def register_model(self, model: Model) -> Model:
|
||||
assert self.client is not None
|
||||
# register_model is called during Llama Stack initialization, so self.client may not have been created yet and we must not initialize it here.
|
||||
# self.client should only be created after the initialization is complete to avoid asyncio cross-context errors.
|
||||
# Changing this may lead to unpredictable behavior.
|
||||
client = self._create_client() if self.client is None else self.client
|
||||
model = await self.register_helper.register_model(model)
|
||||
res = await self.client.models.list()
|
||||
res = await client.models.list()
|
||||
available_models = [m.id async for m in res]
|
||||
if model.provider_resource_id not in available_models:
|
||||
raise ValueError(
|
||||
|
@ -410,6 +424,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
|
|||
output_dimension: Optional[int] = None,
|
||||
task_type: Optional[EmbeddingTaskType] = None,
|
||||
) -> EmbeddingsResponse:
|
||||
self._lazy_initialize_client()
|
||||
assert self.client is not None
|
||||
model = await self._get_model(model_id)
|
||||
|
||||
|
@ -449,6 +464,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
|
|||
guided_choice: Optional[List[str]] = None,
|
||||
prompt_logprobs: Optional[int] = None,
|
||||
) -> OpenAICompletion:
|
||||
self._lazy_initialize_client()
|
||||
model_obj = await self._get_model(model)
|
||||
|
||||
extra_body: Dict[str, Any] = {}
|
||||
|
@ -505,6 +521,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
|
|||
top_p: Optional[float] = None,
|
||||
user: Optional[str] = None,
|
||||
) -> Union[OpenAIChatCompletion, AsyncIterator[OpenAIChatCompletionChunk]]:
|
||||
self._lazy_initialize_client()
|
||||
model_obj = await self._get_model(model)
|
||||
params = await prepare_openai_completion_params(
|
||||
model=model_obj.provider_resource_id,
|
||||
|
|
|
@ -16,7 +16,11 @@ _MODEL_ENTRIES = [
|
|||
build_hf_repo_model_entry(
|
||||
"meta/llama-3.1-8b-instruct",
|
||||
CoreModelId.llama3_1_8b_instruct.value,
|
||||
)
|
||||
),
|
||||
build_hf_repo_model_entry(
|
||||
"meta/llama-3.2-1b-instruct",
|
||||
CoreModelId.llama3_2_1b_instruct.value,
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
|
|
|
@ -27,11 +27,12 @@ from .models import _MODEL_ENTRIES
|
|||
|
||||
# Map API status to JobStatus enum
|
||||
STATUS_MAPPING = {
|
||||
"running": "in_progress",
|
||||
"completed": "completed",
|
||||
"failed": "failed",
|
||||
"cancelled": "cancelled",
|
||||
"pending": "scheduled",
|
||||
"running": JobStatus.in_progress.value,
|
||||
"completed": JobStatus.completed.value,
|
||||
"failed": JobStatus.failed.value,
|
||||
"cancelled": JobStatus.cancelled.value,
|
||||
"pending": JobStatus.scheduled.value,
|
||||
"unknown": JobStatus.scheduled.value,
|
||||
}
|
||||
|
||||
|
||||
|
|
77
llama_stack/providers/remote/safety/nvidia/README.md
Normal file
|
@ -0,0 +1,77 @@
|
|||
# NVIDIA Safety Provider for LlamaStack
|
||||
|
||||
This provider enables safety checks and guardrails for LLM interactions using NVIDIA's NeMo Guardrails service.
|
||||
|
||||
## Features
|
||||
|
||||
- Run safety checks for messages
|
||||
|
||||
## Getting Started
|
||||
|
||||
### Prerequisites
|
||||
|
||||
- LlamaStack with NVIDIA configuration
- Access to the NVIDIA NeMo Guardrails service
- A NIM deployed for the model you want to use for safety checks
|
||||
|
||||
### Setup
|
||||
|
||||
Build the NVIDIA environment:
|
||||
|
||||
```bash
|
||||
llama stack build --template nvidia --image-type conda
|
||||
```
|
||||
|
||||
### Basic Usage using the LlamaStack Python Client
|
||||
|
||||
#### Initialize the client
|
||||
|
||||
```python
|
||||
import os
|
||||
|
||||
os.environ["NVIDIA_API_KEY"] = "your-api-key"
|
||||
os.environ["NVIDIA_GUARDRAILS_URL"] = "http://guardrails.test"
|
||||
|
||||
from llama_stack.distribution.library_client import LlamaStackAsLibraryClient
|
||||
|
||||
client = LlamaStackAsLibraryClient("nvidia")
|
||||
client.initialize()
|
||||
```
|
||||
|
||||
#### Create a safety shield
|
||||
|
||||
```python
|
||||
from llama_stack.apis.safety import Shield
|
||||
from llama_stack.apis.inference import Message
|
||||
|
||||
# Create a safety shield
|
||||
shield = Shield(
|
||||
shield_id="your-shield-id",
|
||||
provider_resource_id="safety-model-id", # The model to use for safety checks
|
||||
description="Safety checks for content moderation",
|
||||
)
|
||||
|
||||
# Register the shield
|
||||
await client.safety.register_shield(shield)
|
||||
```
|
||||
|
||||
#### Run safety checks
|
||||
|
||||
```python
|
||||
# Messages to check
|
||||
messages = [Message(role="user", content="Your message to check")]
|
||||
|
||||
# Run safety check
|
||||
response = await client.safety.run_shield(
|
||||
shield_id="your-shield-id",
|
||||
messages=messages,
|
||||
)
|
||||
|
||||
# Check for violations
|
||||
if response.violation:
|
||||
print(f"Safety violation detected: {response.violation.user_message}")
|
||||
print(f"Violation level: {response.violation.violation_level}")
|
||||
print(f"Metadata: {response.violation.metadata}")
|
||||
else:
|
||||
print("No safety violations detected")
|
||||
```
|
|
@ -8,7 +8,17 @@ import logging
|
|||
import time
|
||||
import uuid
|
||||
import warnings
|
||||
from typing import Any, AsyncGenerator, AsyncIterator, Awaitable, Dict, Iterable, List, Optional, Union
|
||||
from typing import (
|
||||
Any,
|
||||
AsyncGenerator,
|
||||
AsyncIterator,
|
||||
Awaitable,
|
||||
Dict,
|
||||
Iterable,
|
||||
List,
|
||||
Optional,
|
||||
Union,
|
||||
)
|
||||
|
||||
from openai import AsyncStream
|
||||
from openai.types.chat import (
|
||||
|
@ -78,6 +88,7 @@ from llama_stack.apis.common.content_types import (
|
|||
TextDelta,
|
||||
ToolCallDelta,
|
||||
ToolCallParseStatus,
|
||||
_URLOrData,
|
||||
)
|
||||
from llama_stack.apis.inference import (
|
||||
ChatCompletionRequest,
|
||||
|
@ -93,6 +104,7 @@ from llama_stack.apis.inference import (
|
|||
SamplingParams,
|
||||
SystemMessage,
|
||||
TokenLogProbs,
|
||||
ToolChoice,
|
||||
ToolResponseMessage,
|
||||
TopKSamplingStrategy,
|
||||
TopPSamplingStrategy,
|
||||
|
@ -103,7 +115,6 @@ from llama_stack.apis.inference.inference import (
|
|||
OpenAIChatCompletion,
|
||||
OpenAICompletion,
|
||||
OpenAICompletionChoice,
|
||||
OpenAIMessageParam,
|
||||
OpenAIResponseFormatParam,
|
||||
ToolConfig,
|
||||
)
|
||||
|
@ -612,13 +623,10 @@ async def convert_message_to_openai_dict_new(
|
|||
)
|
||||
for tool in message.tool_calls
|
||||
]
|
||||
params = {}
|
||||
if tool_calls:
|
||||
params = {"tool_calls": tool_calls}
|
||||
out = OpenAIChatCompletionAssistantMessage(
|
||||
role="assistant",
|
||||
content=await _convert_message_content(message.content),
|
||||
**params,
|
||||
tool_calls=tool_calls or None,
|
||||
)
|
||||
elif isinstance(message, ToolResponseMessage):
|
||||
out = OpenAIChatCompletionToolMessage(
|
||||
|
@ -695,7 +703,10 @@ def to_openai_param_type(param_type: str) -> dict:
|
|||
if param_type.startswith("list[") and param_type.endswith("]"):
|
||||
inner_type = param_type[5:-1]
|
||||
if inner_type in basic_types:
|
||||
return {"type": "array", "items": {"type": basic_types.get(inner_type, inner_type)}}
|
||||
return {
|
||||
"type": "array",
|
||||
"items": {"type": basic_types.get(inner_type, inner_type)},
|
||||
}
|
||||
|
||||
return {"type": param_type}
|
||||
|
||||
|
@ -815,6 +826,10 @@ def _convert_openai_finish_reason(finish_reason: str) -> StopReason:
|
|||
def _convert_openai_request_tool_config(tool_choice: Optional[Union[str, Dict[str, Any]]] = None) -> ToolConfig:
|
||||
tool_config = ToolConfig()
|
||||
if tool_choice:
|
||||
try:
|
||||
tool_choice = ToolChoice(tool_choice)
|
||||
except ValueError:
|
||||
pass
|
||||
tool_config.tool_choice = tool_choice
|
||||
return tool_config
|
||||
|
||||
|
@ -849,7 +864,9 @@ def _convert_openai_request_tools(tools: Optional[List[Dict[str, Any]]] = None)
|
|||
return lls_tools
|
||||
|
||||
|
||||
def _convert_openai_request_response_format(response_format: OpenAIResponseFormatParam = None):
|
||||
def _convert_openai_request_response_format(
|
||||
response_format: OpenAIResponseFormatParam = None,
|
||||
):
|
||||
if not response_format:
|
||||
return None
|
||||
# response_format can be a dict or a pydantic model
|
||||
|
@ -957,38 +974,50 @@ def _convert_openai_sampling_params(
|
|||
return sampling_params
|
||||
|
||||
|
||||
def _convert_openai_request_messages(messages: List[OpenAIMessageParam]):
|
||||
# Llama Stack messages and OpenAI messages are similar, but not identical.
|
||||
lls_messages = []
|
||||
def openai_messages_to_messages(
|
||||
messages: List[OpenAIChatCompletionMessage],
|
||||
) -> List[Message]:
|
||||
"""
|
||||
Convert a list of OpenAIChatCompletionMessage into a list of Message.
|
||||
"""
|
||||
converted_messages = []
|
||||
for message in messages:
|
||||
lls_message = dict(message)
|
||||
if message.role == "system":
|
||||
converted_message = SystemMessage(content=message.content)
|
||||
elif message.role == "user":
|
||||
converted_message = UserMessage(content=openai_content_to_content(message.content))
|
||||
elif message.role == "assistant":
|
||||
converted_message = CompletionMessage(
|
||||
content=message.content,
|
||||
tool_calls=_convert_openai_tool_calls(message.tool_calls),
|
||||
stop_reason=StopReason.end_of_turn,
|
||||
)
|
||||
elif message.role == "tool":
|
||||
converted_message = ToolResponseMessage(
|
||||
role="tool",
|
||||
call_id=message.tool_call_id,
|
||||
content=openai_content_to_content(message.content),
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unknown role {message.role}")
|
||||
converted_messages.append(converted_message)
|
||||
return converted_messages
|
||||
|
||||
# Llama Stack expects `call_id` but OpenAI uses `tool_call_id`
|
||||
tool_call_id = lls_message.pop("tool_call_id", None)
|
||||
if tool_call_id:
|
||||
lls_message["call_id"] = tool_call_id
|
||||
|
||||
content = lls_message.get("content", None)
|
||||
if isinstance(content, list):
|
||||
lls_content = []
|
||||
for item in content:
|
||||
# items can either by pydantic models or dicts here...
|
||||
item = dict(item)
|
||||
if item.get("type", "") == "image_url":
|
||||
lls_item = ImageContentItem(
|
||||
type="image",
|
||||
image=URL(uri=item.get("image_url", {}).get("url", "")),
|
||||
)
|
||||
elif item.get("type", "") == "text":
|
||||
lls_item = TextContentItem(
|
||||
type="text",
|
||||
text=item.get("text", ""),
|
||||
)
|
||||
lls_content.append(lls_item)
|
||||
lls_message["content"] = lls_content
|
||||
lls_messages.append(lls_message)
|
||||
|
||||
return lls_messages
|
||||
def openai_content_to_content(content: Union[str, Iterable[OpenAIChatCompletionContentPartParam]]):
|
||||
if isinstance(content, str):
|
||||
return content
|
||||
elif isinstance(content, list):
|
||||
return [openai_content_to_content(c) for c in content]
|
||||
elif hasattr(content, "type"):
|
||||
if content.type == "text":
|
||||
return TextContentItem(type="text", text=content.text)
|
||||
elif content.type == "image_url":
|
||||
return ImageContentItem(type="image", image=_URLOrData(url=URL(uri=content.image_url.url)))
|
||||
else:
|
||||
raise ValueError(f"Unknown content type: {content.type}")
|
||||
else:
|
||||
raise ValueError(f"Unknown content type: {content}")
|
||||
|
||||
|
||||
def convert_openai_chat_completion_choice(
|
||||
|
@ -1313,7 +1342,7 @@ class OpenAIChatCompletionToLlamaStackMixin:
|
|||
top_p: Optional[float] = None,
|
||||
user: Optional[str] = None,
|
||||
) -> Union[OpenAIChatCompletion, AsyncIterator[OpenAIChatCompletionChunk]]:
|
||||
messages = _convert_openai_request_messages(messages)
|
||||
messages = openai_messages_to_messages(messages)
|
||||
response_format = _convert_openai_request_response_format(response_format)
|
||||
sampling_params = _convert_openai_sampling_params(
|
||||
max_tokens=max_tokens,
|
||||
|
@ -1321,7 +1350,10 @@ class OpenAIChatCompletionToLlamaStackMixin:
|
|||
top_p=top_p,
|
||||
)
|
||||
tool_config = _convert_openai_request_tool_config(tool_choice)
|
||||
|
||||
tools = _convert_openai_request_tools(tools)
|
||||
if tool_config.tool_choice == ToolChoice.none:
|
||||
tools = []
|
||||
|
||||
outstanding_responses = []
|
||||
# "n" is the number of completions to generate per prompt
|
||||
|
@ -1346,7 +1378,9 @@ class OpenAIChatCompletionToLlamaStackMixin:
|
|||
)
|
||||
|
||||
async def _process_stream_response(
|
||||
self, model: str, outstanding_responses: List[Awaitable[AsyncIterator[ChatCompletionResponseStreamChunk]]]
|
||||
self,
|
||||
model: str,
|
||||
outstanding_responses: List[Awaitable[AsyncIterator[ChatCompletionResponseStreamChunk]]],
|
||||
):
|
||||
id = f"chatcmpl-{uuid.uuid4()}"
|
||||
for outstanding_response in outstanding_responses:
|
||||
|
@ -1369,11 +1403,31 @@ class OpenAIChatCompletionToLlamaStackMixin:
|
|||
elif isinstance(event.delta, ToolCallDelta):
|
||||
if event.delta.parse_status == ToolCallParseStatus.succeeded:
|
||||
tool_call = event.delta.tool_call
|
||||
|
||||
# First chunk includes full structure
|
||||
openai_tool_call = OpenAIChoiceDeltaToolCall(
|
||||
index=0,
|
||||
id=tool_call.call_id,
|
||||
function=OpenAIChoiceDeltaToolCallFunction(
|
||||
name=tool_call.tool_name, arguments=tool_call.arguments_json
|
||||
name=tool_call.tool_name,
|
||||
arguments="",
|
||||
),
|
||||
)
|
||||
delta = OpenAIChoiceDelta(tool_calls=[openai_tool_call])
|
||||
yield OpenAIChatCompletionChunk(
|
||||
id=id,
|
||||
choices=[
|
||||
OpenAIChatCompletionChunkChoice(index=i, finish_reason=finish_reason, delta=delta)
|
||||
],
|
||||
created=int(time.time()),
|
||||
model=model,
|
||||
object="chat.completion.chunk",
|
||||
)
|
||||
# arguments
|
||||
openai_tool_call = OpenAIChoiceDeltaToolCall(
|
||||
index=0,
|
||||
function=OpenAIChoiceDeltaToolCallFunction(
|
||||
arguments=tool_call.arguments_json,
|
||||
),
|
||||
)
|
||||
delta = OpenAIChoiceDelta(tool_calls=[openai_tool_call])
|
||||
|
|
|
@ -25,14 +25,84 @@ The following models are available by default:
|
|||
{% endif %}
|
||||
|
||||
|
||||
### Prerequisite: API Keys
|
||||
## Prerequisites
|
||||
### NVIDIA API Keys
|
||||
|
||||
Make sure you have access to a NVIDIA API Key. You can get one by visiting [https://build.nvidia.com/](https://build.nvidia.com/).
|
||||
Make sure you have access to an NVIDIA API Key. You can get one by visiting [https://build.nvidia.com/](https://build.nvidia.com/). Use this key for the `NVIDIA_API_KEY` environment variable.
|
||||
|
||||
### Deploy NeMo Microservices Platform
|
||||
The NVIDIA NeMo microservices platform supports end-to-end microservice deployment of a complete AI flywheel on your Kubernetes cluster through the NeMo Microservices Helm Chart. Please reference the [NVIDIA NeMo Microservices documentation](https://docs.nvidia.com/nemo/microservices/latest/about/index.html) for platform prerequisites and instructions to install and deploy the platform.
|
||||
|
||||
## Supported Services
|
||||
Each Llama Stack API corresponds to a specific NeMo microservice. The core microservices (Customizer, Evaluator, Guardrails) are exposed by the same endpoint. The platform components (Data Store) are each exposed by separate endpoints.
|
||||
|
||||
### Inference: NVIDIA NIM
|
||||
NVIDIA NIM is used for running inference with registered models. There are two ways to access NVIDIA NIMs:
|
||||
1. Hosted (default): Preview APIs hosted at https://integrate.api.nvidia.com (Requires an API key)
|
||||
2. Self-hosted: NVIDIA NIMs that run on your own infrastructure.
|
||||
|
||||
The deployed platform includes the NIM Proxy microservice, which is the service that provides access to your NIMs (for example, to run inference on a model). Set the `NVIDIA_BASE_URL` environment variable to use your NVIDIA NIM Proxy deployment.
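For example, you can point the stack at either the hosted preview endpoint or your own NIM Proxy deployment (the self-hosted URL below is only a placeholder):

```bash
# Hosted preview NIMs (requires NVIDIA_API_KEY)
export NVIDIA_BASE_URL="https://integrate.api.nvidia.com"

# Or, for a self-hosted NIM Proxy deployment (placeholder URL)
export NVIDIA_BASE_URL="http://nim.test"
```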
|
||||
|
||||
### Datasetio API: NeMo Data Store
|
||||
The NeMo Data Store microservice serves as the default file storage solution for the NeMo microservices platform. It exposes APIs compatible with the Hugging Face Hub client (`HfApi`), so you can use the client to interact with the Data Store. The `NVIDIA_DATASETS_URL` environment variable should point to your NeMo Data Store endpoint.
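As a minimal sketch (assuming your Data Store accepts standard Hub client calls; the namespace, dataset name, and file path are placeholders), uploading a dataset file might look like this:

```python
import os

from huggingface_hub import HfApi

# Point the Hub client at the NeMo Data Store instead of huggingface.co.
api = HfApi(endpoint=os.environ["NVIDIA_DATASETS_URL"])

# Placeholder namespace/dataset and file names.
api.create_repo(repo_id="default/my-dataset", repo_type="dataset", exist_ok=True)
api.upload_file(
    path_or_fileobj="./train.jsonl",
    path_in_repo="train.jsonl",
    repo_id="default/my-dataset",
    repo_type="dataset",
)
```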
|
||||
|
||||
See the [NVIDIA Datasetio docs](/llama_stack/providers/remote/datasetio/nvidia/README.md) for supported features and example usage.
|
||||
|
||||
### Eval API: NeMo Evaluator
|
||||
The NeMo Evaluator microservice supports evaluation of LLMs. Launching an Evaluation job with NeMo Evaluator requires an Evaluation Config (an object that contains metadata needed by the job). A Llama Stack Benchmark maps to an Evaluation Config, so registering a Benchmark creates an Evaluation Config in NeMo Evaluator. The `NVIDIA_EVALUATOR_URL` environment variable should point to your NeMo Microservices endpoint.
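As a hypothetical sketch only (the exact client method and argument names may differ; see the NVIDIA Eval docs linked below), registering a benchmark from the Llama Stack client could look like:

```python
# Hypothetical sketch: the benchmark id, dataset id, and scoring function
# names are placeholders. Registering the benchmark creates the matching
# Evaluation Config in NeMo Evaluator.
client.benchmarks.register(
    benchmark_id="my-eval-config",
    dataset_id="my-dataset",
    scoring_functions=["basic::equality"],
)
```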
|
||||
|
||||
See the [NVIDIA Eval docs](/llama_stack/providers/remote/eval/nvidia/README.md) for supported features and example usage.
|
||||
|
||||
### Post-Training API: NeMo Customizer
|
||||
The NeMo Customizer microservice supports fine-tuning models. You can reference [this list of supported models](/llama_stack/providers/remote/post_training/nvidia/models.py) that can be fine-tuned using Llama Stack. The `NVIDIA_CUSTOMIZER_URL` environment variable should point to your NeMo Microservices endpoint.
|
||||
|
||||
See the [NVIDIA Post-Training docs](/llama_stack/providers/remote/post_training/nvidia/README.md) for supported features and example usage.
|
||||
|
||||
### Safety API: NeMo Guardrails
|
||||
The NeMo Guardrails microservice sits between your application and the LLM, and adds checks and content moderation to a model. The `GUARDRAILS_SERVICE_URL` environment variable should point to your NeMo Microservices endpoint.
|
||||
|
||||
See the NVIDIA Safety docs for supported features and example usage.
|
||||
|
||||
## Deploying models
|
||||
In order to use a registered model with the Llama Stack APIs, ensure the corresponding NIM is deployed to your environment. For example, you can use the NIM Proxy microservice to deploy `meta/llama-3.2-1b-instruct`.
|
||||
|
||||
Note: For improved inference speeds, we need to use NIM with the `fast_outlines` guided decoding system (specified in the request body). This is the default if you deployed the platform with the NeMo Microservices Helm Chart.
|
||||
```sh
|
||||
# URL to NeMo NIM Proxy service
|
||||
export NEMO_URL="http://nemo.test"
|
||||
|
||||
curl --location "$NEMO_URL/v1/deployment/model-deployments" \
|
||||
-H 'accept: application/json' \
|
||||
-H 'Content-Type: application/json' \
|
||||
-d '{
|
||||
"name": "llama-3.2-1b-instruct",
|
||||
"namespace": "meta",
|
||||
"config": {
|
||||
"model": "meta/llama-3.2-1b-instruct",
|
||||
"nim_deployment": {
|
||||
"image_name": "nvcr.io/nim/meta/llama-3.2-1b-instruct",
|
||||
"image_tag": "1.8.3",
|
||||
"pvc_size": "25Gi",
|
||||
"gpu": 1,
|
||||
"additional_envs": {
|
||||
"NIM_GUIDED_DECODING_BACKEND": "fast_outlines"
|
||||
}
|
||||
}
|
||||
}
|
||||
}'
|
||||
```
|
||||
This NIM deployment should take approximately 10 minutes to go live. [See the docs](https://docs.nvidia.com/nemo/microservices/latest/get-started/tutorials/deploy-nims.html) for more information on how to deploy a NIM and verify it's available for inference.
|
||||
|
||||
You can also remove a deployed NIM to free up GPU resources, if needed.
|
||||
```sh
|
||||
export NEMO_URL="http://nemo.test"
|
||||
|
||||
curl -X DELETE "$NEMO_URL/v1/deployment/model-deployments/meta/llama-3.1-8b-instruct"
|
||||
```
|
||||
|
||||
## Running Llama Stack with NVIDIA
|
||||
|
||||
You can do this via Conda (build code) or Docker which has a pre-built image.
|
||||
You can do this via Conda or venv (build the code), or via Docker, which has a pre-built image.
|
||||
|
||||
### Via Docker
|
||||
|
||||
|
@ -54,9 +124,23 @@ docker run \
|
|||
### Via Conda
|
||||
|
||||
```bash
|
||||
INFERENCE_MODEL=meta-llama/Llama-3.1-8b-Instruct
|
||||
llama stack build --template nvidia --image-type conda
|
||||
llama stack run ./run.yaml \
|
||||
--port 8321 \
|
||||
--env NVIDIA_API_KEY=$NVIDIA_API_KEY
|
||||
--env NVIDIA_API_KEY=$NVIDIA_API_KEY \
|
||||
--env INFERENCE_MODEL=$INFERENCE_MODEL
|
||||
```
|
||||
|
||||
### Via venv
|
||||
|
||||
If you've set up your local development environment, you can also build the image using your local virtual environment.
|
||||
|
||||
```bash
|
||||
INFERENCE_MODEL=meta-llama/Llama-3.1-8b-Instruct
|
||||
llama stack build --template nvidia --image-type venv
|
||||
llama stack run ./run.yaml \
|
||||
--port 8321 \
|
||||
--env NVIDIA_API_KEY=$NVIDIA_API_KEY \
|
||||
--env INFERENCE_MODEL=$INFERENCE_MODEL
|
||||
```
|
||||
|
|
|
@ -59,7 +59,7 @@ def get_distribution_template() -> DistributionTemplate:
|
|||
default_models = get_model_registry(available_models)
|
||||
return DistributionTemplate(
|
||||
name="nvidia",
|
||||
distro_type="remote_hosted",
|
||||
distro_type="self_hosted",
|
||||
description="Use NVIDIA NIM for running LLM inference and safety",
|
||||
container_image=None,
|
||||
template_path=Path(__file__).parent / "doc_template.md",
|
||||
|
|
|
@ -174,6 +174,16 @@ models:
|
|||
provider_id: nvidia
|
||||
provider_model_id: meta/llama-3.2-90b-vision-instruct
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: meta/llama-3.3-70b-instruct
|
||||
provider_id: nvidia
|
||||
provider_model_id: meta/llama-3.3-70b-instruct
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: meta-llama/Llama-3.3-70B-Instruct
|
||||
provider_id: nvidia
|
||||
provider_model_id: meta/llama-3.3-70b-instruct
|
||||
model_type: llm
|
||||
- metadata:
|
||||
embedding_dimension: 2048
|
||||
context_length: 8192
|
||||
|
|
|
@ -31,7 +31,7 @@ The following environment variables can be configured:
|
|||
In the following sections, we'll use AMD, NVIDIA or Intel GPUs to serve as hardware accelerators for the vLLM
|
||||
server, which acts as both the LLM inference provider and the safety provider. Note that vLLM also
|
||||
[supports many other hardware accelerators](https://docs.vllm.ai/en/latest/getting_started/installation.html) and
|
||||
that we only use GPUs here for demonstration purposes.
|
||||
that we only use GPUs here for demonstration purposes. Note that if you run into issues, you can add the environment variable `--env VLLM_DEBUG_LOG_API_SERVER_RESPONSE=true` (available in vLLM v0.8.3 and above) to the `docker run` command to enable logging of API server responses for debugging.
|
||||
|
||||
### Setting up vLLM server on AMD GPU
|
||||
|
||||
|
|
|
@ -46,6 +46,7 @@ dev = [
|
|||
"pytest-asyncio",
|
||||
"pytest-cov",
|
||||
"pytest-html",
|
||||
"pytest-json-report",
|
||||
"nbval", # For notebook testing
|
||||
"black",
|
||||
"ruff",
|
||||
|
@ -57,7 +58,16 @@ dev = [
|
|||
"ruamel.yaml", # needed for openapi generator
|
||||
]
|
||||
# These are the dependencies required for running unit tests.
|
||||
unit = ["sqlite-vec", "openai", "aiosqlite", "aiohttp", "pypdf", "chardet", "qdrant-client"]
|
||||
unit = [
|
||||
"sqlite-vec",
|
||||
"openai",
|
||||
"aiosqlite",
|
||||
"aiohttp",
|
||||
"pypdf",
|
||||
"chardet",
|
||||
"qdrant-client",
|
||||
"opentelemetry-exporter-otlp-proto-http"
|
||||
]
|
||||
# These are the core dependencies required for running integration tests. They are shared across all
|
||||
# providers. If a provider requires additional dependencies, please add them to your environment
|
||||
# separately. If you are using "uv" to execute your tests, you can use the "--with" flag to specify extra
|
||||
|
|
|
@ -98,7 +98,7 @@ def collect_template_dependencies(template_dir: Path) -> tuple[str | None, list[
|
|||
|
||||
if template_func := getattr(module, "get_distribution_template", None):
|
||||
template = template_func()
|
||||
normal_deps, special_deps = get_provider_dependencies(template.providers)
|
||||
normal_deps, special_deps = get_provider_dependencies(template)
|
||||
# Combine all dependencies in order: normal deps, special deps, server deps
|
||||
all_deps = sorted(set(normal_deps + SERVER_DEPENDENCIES)) + sorted(set(special_deps))
|
||||
|
||||
|
|
|
@ -0,0 +1,9 @@
|
|||
version: '2'
|
||||
distribution_spec:
|
||||
description: Custom distro for CI tests
|
||||
providers:
|
||||
inference:
|
||||
- remote::custom_ollama
|
||||
image_type: container
|
||||
image_name: ci-test
|
||||
external_providers_dir: /tmp/providers.d
|
|
@ -1,6 +1,6 @@
|
|||
adapter:
|
||||
adapter_type: custom_ollama
|
||||
pip_packages: ["ollama", "aiohttp"]
|
||||
pip_packages: ["ollama", "aiohttp", "tests/external-provider/llama-stack-provider-ollama"]
|
||||
config_class: llama_stack_provider_ollama.config.OllamaImplConfig
|
||||
module: llama_stack_provider_ollama
|
||||
api_dependencies: []
|
||||
|
|
|
@ -1,14 +1,10 @@
|
|||
version: '2'
|
||||
image_name: ollama
|
||||
apis:
|
||||
- agents
|
||||
- datasetio
|
||||
- eval
|
||||
- inference
|
||||
- safety
|
||||
- scoring
|
||||
- telemetry
|
||||
- tool_runtime
|
||||
- datasetio
|
||||
- vector_io
|
||||
providers:
|
||||
inference:
|
||||
|
@ -24,19 +20,6 @@ providers:
|
|||
type: sqlite
|
||||
namespace: null
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ollama}/faiss_store.db
|
||||
safety:
|
||||
- provider_id: llama-guard
|
||||
provider_type: inline::llama-guard
|
||||
config:
|
||||
excluded_categories: []
|
||||
agents:
|
||||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
config:
|
||||
persistence_store:
|
||||
type: sqlite
|
||||
namespace: null
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ollama}/agents_store.db
|
||||
telemetry:
|
||||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
|
@ -44,14 +27,6 @@ providers:
|
|||
service_name: ${env.OTEL_SERVICE_NAME:llama-stack}
|
||||
sinks: ${env.TELEMETRY_SINKS:console,sqlite}
|
||||
sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/ollama/trace_store.db}
|
||||
eval:
|
||||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
config:
|
||||
kvstore:
|
||||
type: sqlite
|
||||
namespace: null
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ollama}/meta_reference_eval.db
|
||||
datasetio:
|
||||
- provider_id: huggingface
|
||||
provider_type: remote::huggingface
|
||||
|
@ -67,17 +42,6 @@ providers:
|
|||
type: sqlite
|
||||
namespace: null
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ollama}/localfs_datasetio.db
|
||||
scoring:
|
||||
- provider_id: basic
|
||||
provider_type: inline::basic
|
||||
config: {}
|
||||
- provider_id: llm-as-judge
|
||||
provider_type: inline::llm-as-judge
|
||||
config: {}
|
||||
- provider_id: braintrust
|
||||
provider_type: inline::braintrust
|
||||
config:
|
||||
openai_api_key: ${env.OPENAI_API_KEY:}
|
||||
tool_runtime:
|
||||
- provider_id: brave-search
|
||||
provider_type: remote::brave-search
|
||||
|
|
|
@ -115,6 +115,70 @@ def test_agent_simple(llama_stack_client_with_mocked_inference, agent_config):
|
|||
assert "I can't" in logs_str
|
||||
|
||||
|
||||
def test_agent_name(llama_stack_client, text_model_id):
|
||||
agent_name = f"test-agent-{uuid4()}"
|
||||
|
||||
try:
|
||||
agent = Agent(
|
||||
llama_stack_client,
|
||||
model=text_model_id,
|
||||
instructions="You are a helpful assistant",
|
||||
name=agent_name,
|
||||
)
|
||||
except TypeError:
|
||||
agent = Agent(
|
||||
llama_stack_client,
|
||||
model=text_model_id,
|
||||
instructions="You are a helpful assistant",
|
||||
)
|
||||
return
|
||||
|
||||
session_id = agent.create_session(f"test-session-{uuid4()}")
|
||||
|
||||
agent.create_turn(
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Give me a sentence that contains the word: hello",
|
||||
}
|
||||
],
|
||||
session_id=session_id,
|
||||
stream=False,
|
||||
)
|
||||
|
||||
all_spans = []
|
||||
for span in llama_stack_client.telemetry.query_spans(
|
||||
attribute_filters=[
|
||||
{"key": "session_id", "op": "eq", "value": session_id},
|
||||
],
|
||||
attributes_to_return=["input", "output", "agent_name", "agent_id", "session_id"],
|
||||
):
|
||||
all_spans.append(span.attributes)
|
||||
|
||||
agent_name_spans = []
|
||||
for span in llama_stack_client.telemetry.query_spans(
|
||||
attribute_filters=[],
|
||||
attributes_to_return=["agent_name"],
|
||||
):
|
||||
if "agent_name" in span.attributes:
|
||||
agent_name_spans.append(span.attributes)
|
||||
|
||||
agent_logs = []
|
||||
for span in llama_stack_client.telemetry.query_spans(
|
||||
attribute_filters=[
|
||||
{"key": "agent_name", "op": "eq", "value": agent_name},
|
||||
],
|
||||
attributes_to_return=["input", "output", "agent_name"],
|
||||
):
|
||||
if "output" in span.attributes and span.attributes["output"] != "no shields":
|
||||
agent_logs.append(span.attributes)
|
||||
|
||||
assert len(agent_logs) == 1
|
||||
assert agent_logs[0]["agent_name"] == agent_name
|
||||
assert "Give me a sentence that contains the word: hello" in agent_logs[0]["input"]
|
||||
assert "hello" in agent_logs[0]["output"].lower()
|
||||
|
||||
|
||||
def test_tool_config(llama_stack_client_with_mocked_inference, agent_config):
|
||||
common_params = dict(
|
||||
model="meta-llama/Llama-3.2-3B-Instruct",
|
||||
|
|
|
@ -31,6 +31,7 @@ def data_url_from_file(file_path: str) -> str:
|
|||
return data_url
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="flaky. Couldn't find 'llamastack/simpleqa' on the Hugging Face Hub")
|
||||
@pytest.mark.parametrize(
|
||||
"purpose, source, provider_id, limit",
|
||||
[
|
||||
|
|
40
tests/unit/distribution/test_build_path.py
Normal file
|
@ -0,0 +1,40 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
from llama_stack.cli.stack._build import (
|
||||
_run_stack_build_command_from_build_config,
|
||||
)
|
||||
from llama_stack.distribution.datatypes import BuildConfig, DistributionSpec
|
||||
from llama_stack.distribution.utils.image_types import LlamaStackImageType
|
||||
|
||||
|
||||
def test_container_build_passes_path(monkeypatch, tmp_path):
|
||||
called_with = {}
|
||||
|
||||
def spy_build_image(cfg, build_file_path, image_name, template_or_config, run_config=None):
|
||||
called_with["path"] = template_or_config
|
||||
called_with["run_config"] = run_config
|
||||
return 0
|
||||
|
||||
monkeypatch.setattr(
|
||||
"llama_stack.cli.stack._build.build_image",
|
||||
spy_build_image,
|
||||
raising=True,
|
||||
)
|
||||
|
||||
cfg = BuildConfig(
|
||||
image_type=LlamaStackImageType.CONTAINER.value,
|
||||
distribution_spec=DistributionSpec(providers={}, description=""),
|
||||
)
|
||||
|
||||
_run_stack_build_command_from_build_config(cfg, image_name="dummy")
|
||||
|
||||
assert "path" in called_with
|
||||
assert isinstance(called_with["path"], str)
|
||||
assert Path(called_with["path"]).exists()
|
||||
assert called_with["run_config"] is None
|
|
@ -216,35 +216,48 @@ class TestNvidiaPostTraining(unittest.TestCase):
|
|||
)
|
||||
|
||||
def test_get_training_job_status(self):
|
||||
self.mock_make_request.return_value = {
|
||||
"created_at": "2024-12-09T04:06:28.580220",
|
||||
"updated_at": "2024-12-09T04:21:19.852832",
|
||||
"status": "completed",
|
||||
"steps_completed": 1210,
|
||||
"epochs_completed": 2,
|
||||
"percentage_done": 100.0,
|
||||
"best_epoch": 2,
|
||||
"train_loss": 1.718016266822815,
|
||||
"val_loss": 1.8661999702453613,
|
||||
}
|
||||
customizer_status_to_job_status = [
|
||||
("running", "in_progress"),
|
||||
("completed", "completed"),
|
||||
("failed", "failed"),
|
||||
("cancelled", "cancelled"),
|
||||
("pending", "scheduled"),
|
||||
("unknown", "scheduled"),
|
||||
]
|
||||
|
||||
job_id = "cust-JGTaMbJMdqjJU8WbQdN9Q2"
|
||||
for customizer_status, expected_status in customizer_status_to_job_status:
|
||||
with self.subTest(customizer_status=customizer_status, expected_status=expected_status):
|
||||
self.mock_make_request.return_value = {
|
||||
"created_at": "2024-12-09T04:06:28.580220",
|
||||
"updated_at": "2024-12-09T04:21:19.852832",
|
||||
"status": customizer_status,
|
||||
"steps_completed": 1210,
|
||||
"epochs_completed": 2,
|
||||
"percentage_done": 100.0,
|
||||
"best_epoch": 2,
|
||||
"train_loss": 1.718016266822815,
|
||||
"val_loss": 1.8661999702453613,
|
||||
}
|
||||
|
||||
status = self.run_async(self.adapter.get_training_job_status(job_uuid=job_id))
|
||||
job_id = "cust-JGTaMbJMdqjJU8WbQdN9Q2"
|
||||
|
||||
assert isinstance(status, NvidiaPostTrainingJobStatusResponse)
|
||||
assert status.status.value == "completed"
|
||||
assert status.steps_completed == 1210
|
||||
assert status.epochs_completed == 2
|
||||
assert status.percentage_done == 100.0
|
||||
assert status.best_epoch == 2
|
||||
assert status.train_loss == 1.718016266822815
|
||||
assert status.val_loss == 1.8661999702453613
|
||||
status = self.run_async(self.adapter.get_training_job_status(job_uuid=job_id))
|
||||
|
||||
self.mock_make_request.assert_called_once()
|
||||
self._assert_request(
|
||||
self.mock_make_request, "GET", f"/v1/customization/jobs/{job_id}/status", expected_params={"job_id": job_id}
|
||||
)
|
||||
assert isinstance(status, NvidiaPostTrainingJobStatusResponse)
|
||||
assert status.status.value == expected_status
|
||||
assert status.steps_completed == 1210
|
||||
assert status.epochs_completed == 2
|
||||
assert status.percentage_done == 100.0
|
||||
assert status.best_epoch == 2
|
||||
assert status.train_loss == 1.718016266822815
|
||||
assert status.val_loss == 1.8661999702453613
|
||||
|
||||
self._assert_request(
|
||||
self.mock_make_request,
|
||||
"GET",
|
||||
f"/v1/customization/jobs/{job_id}/status",
|
||||
expected_params={"job_id": job_id},
|
||||
)
|
||||
|
||||
def test_get_training_jobs(self):
|
||||
job_id = "cust-JGTaMbJMdqjJU8WbQdN9Q2"
|
||||
|
|
55
tests/unit/server/test_sse.py
Normal file
|
@ -0,0 +1,55 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import asyncio
|
||||
|
||||
import pytest
|
||||
|
||||
from llama_stack.distribution.server.server import create_sse_event, sse_generator
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_sse_generator_basic():
|
||||
# An AsyncIterator wrapped in an Awaitable, just like our web methods
|
||||
async def async_event_gen():
|
||||
async def event_gen():
|
||||
yield "Test event 1"
|
||||
yield "Test event 2"
|
||||
|
||||
return event_gen()
|
||||
|
||||
sse_gen = sse_generator(async_event_gen())
|
||||
assert sse_gen is not None
|
||||
|
||||
# Test that the events are streamed correctly
|
||||
seen_events = []
|
||||
async for event in sse_gen:
|
||||
seen_events.append(event)
|
||||
assert len(seen_events) == 2
|
||||
assert seen_events[0] == create_sse_event("Test event 1")
|
||||
assert seen_events[1] == create_sse_event("Test event 2")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_sse_generator_client_disconnected():
|
||||
# An AsyncIterator wrapped in an Awaitable, just like our web methods
|
||||
async def async_event_gen():
|
||||
async def event_gen():
|
||||
yield "Test event 1"
|
||||
# Simulate a client disconnect before emitting event 2
|
||||
raise asyncio.CancelledError()
|
||||
|
||||
return event_gen()
|
||||
|
||||
sse_gen = sse_generator(async_event_gen())
|
||||
assert sse_gen is not None
|
||||
|
||||
# Start reading the events, ensuring this doesn't raise an exception
|
||||
seen_events = []
|
||||
async for event in sse_gen:
|
||||
seen_events.append(event)
|
||||
assert len(seen_events) == 1
|
||||
assert seen_events[0] == create_sse_event("Test event 1")
|
|
@ -8,29 +8,44 @@ This framework allows you to run the same set of verification tests against diff
|
|||
|
||||
## Features
|
||||
|
||||
The verification suite currently tests:
|
||||
The verification suite currently tests the following in both streaming and non-streaming modes:
|
||||
|
||||
- Basic chat completions (streaming and non-streaming)
|
||||
- Basic chat completions
|
||||
- Image input capabilities
|
||||
- Structured JSON output formatting
|
||||
- Tool calling functionality
|
||||
|
||||
## Report
|
||||
|
||||
The latest report can be found at [REPORT.md](REPORT.md).
|
||||
|
||||
To update the report, ensure you have the API keys set,
|
||||
```bash
|
||||
export OPENAI_API_KEY=<your_openai_api_key>
|
||||
export FIREWORKS_API_KEY=<your_fireworks_api_key>
|
||||
export TOGETHER_API_KEY=<your_together_api_key>
|
||||
```
|
||||
then run
|
||||
```bash
|
||||
uv run --with-editable ".[dev]" python tests/verifications/generate_report.py --run-tests
|
||||
```
|
||||
|
||||
## Running Tests
|
||||
|
||||
To run the verification tests, use pytest with the following parameters:
|
||||
|
||||
```bash
|
||||
cd llama-stack
|
||||
pytest tests/verifications/openai --provider=<provider-name>
|
||||
pytest tests/verifications/openai_api --provider=<provider-name>
|
||||
```
|
||||
|
||||
Example:
|
||||
```bash
|
||||
# Run all tests
|
||||
pytest tests/verifications/openai --provider=together
|
||||
pytest tests/verifications/openai_api --provider=together
|
||||
|
||||
# Only run tests with Llama 4 models
|
||||
pytest tests/verifications/openai --provider=together -k 'Llama-4'
|
||||
pytest tests/verifications/openai_api --provider=together -k 'Llama-4'
|
||||
```
|
||||
|
||||
### Parameters
|
||||
|
@ -41,23 +56,22 @@ pytest tests/verifications/openai --provider=together -k 'Llama-4'
|
|||
|
||||
## Supported Providers
|
||||
|
||||
The verification suite currently supports:
|
||||
- OpenAI
|
||||
- Fireworks
|
||||
- Together
|
||||
- Groq
|
||||
- Cerebras
|
||||
The verification suite supports any provider with an OpenAI compatible endpoint.
|
||||
|
||||
See `tests/verifications/conf/` for the list of supported providers.
|
||||
|
||||
To run on a new provider, simply add a new yaml file to the `conf/` directory with the provider config. See `tests/verifications/conf/together.yaml` for an example.
|
||||
|
||||
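A minimal sketch of such a provider config, modeled on the `meta_reference.yaml` file added later in this PR — the field names mirror that file, while the URL, environment variable, and model IDs below are placeholders:

```yaml
# tests/verifications/conf/my_provider.yaml (hypothetical)
base_url: http://localhost:8000/v1        # the provider's OpenAI-compatible endpoint
api_key_var: MY_PROVIDER_API_KEY          # env var that holds the API key
models:
  - my-org/my-model-id
model_display_names:
  my-org/my-model-id: My-Model
test_exclusions: {}                       # optionally: model -> list of tests to skip
```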
## Adding New Test Cases

To add new test cases, create appropriate JSON files in the `openai/fixtures/test_cases/` directory following the existing patterns.
To add new test cases, create appropriate JSON files in the `openai_api/fixtures/test_cases/` directory following the existing patterns (a sketch follows below).
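For illustration, a new case might look like the following, mirroring the structure of the `chat_completion.yaml` fixture shown later in this diff (the test name, case id, and expected output are invented):

```yaml
test_chat_my_new_case:
  test_name: test_chat_my_new_case
  test_params:
    case:
      - case_id: "mars"
        input:
          messages:
            - content: Which planet is known as the Red Planet?
              role: user
        output: Mars
```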
## Structure

- `__init__.py` - Marks the directory as a Python package
- `conftest.py` - Global pytest configuration and fixtures
- `openai/` - Tests specific to OpenAI-compatible APIs
- `conf/` - Provider-specific configuration files
- `openai_api/` - Tests specific to OpenAI-compatible APIs
  - `fixtures/` - Test fixtures and utilities
    - `fixtures.py` - Provider-specific fixtures
    - `load.py` - Utilities for loading test cases
@@ -1,6 +1,6 @@
# Test Results Report

*Generated on: 2025-04-14 18:11:37*
*Generated on: 2025-04-17 12:42:33*

*This report was generated by running `python tests/verifications/generate_report.py`*
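(For context, the generator also accepts provider and keyword filters — see the `--providers`, `--k`, and `--output` arguments in `generate_report.py` later in this diff. A hypothetical invocation that regenerates results for just two providers might look like:)

```bash
# Assumes the API keys are exported as described in the README
uv run --with-editable ".[dev]" python tests/verifications/generate_report.py \
    --run-tests --providers together,openai --k "test_chat_non_streaming_basic"
```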
@@ -15,22 +15,74 @@

| Provider | Pass Rate | Tests Passed | Total Tests |
| --- | --- | --- | --- |
| Together | 48.7% | 37 | 76 |
| Fireworks | 47.4% | 36 | 76 |
| Openai | 100.0% | 52 | 52 |
| Meta_reference | 100.0% | 28 | 28 |
| Together | 50.0% | 40 | 80 |
| Fireworks | 50.0% | 40 | 80 |
| Openai | 100.0% | 56 | 56 |


## Meta_reference

*Tests run on: 2025-04-17 12:37:11*

```bash
# Run all tests for this provider:
pytest tests/verifications/openai_api/test_chat_completion.py --provider=meta_reference -v

# Example: Run only the 'stream=False' case of test_chat_multi_turn_multiple_images:
pytest tests/verifications/openai_api/test_chat_completion.py --provider=meta_reference -k "test_chat_multi_turn_multiple_images and stream=False"
```

**Model Key (Meta_reference)**

| Display Name | Full Model ID |
| --- | --- |
| Llama-4-Scout-Instruct | `meta-llama/Llama-4-Scout-17B-16E-Instruct` |

| Test | Llama-4-Scout-Instruct |
| --- | --- |
| test_chat_multi_turn_multiple_images (stream=False) | ✅ |
| test_chat_multi_turn_multiple_images (stream=True) | ✅ |
| test_chat_non_streaming_basic (earth) | ✅ |
| test_chat_non_streaming_basic (saturn) | ✅ |
| test_chat_non_streaming_image | ✅ |
| test_chat_non_streaming_multi_turn_tool_calling (add_product_tool) | ✅ |
| test_chat_non_streaming_multi_turn_tool_calling (compare_monthly_expense_tool) | ✅ |
| test_chat_non_streaming_multi_turn_tool_calling (get_then_create_event_tool) | ✅ |
| test_chat_non_streaming_multi_turn_tool_calling (text_then_weather_tool) | ✅ |
| test_chat_non_streaming_multi_turn_tool_calling (weather_tool_then_text) | ✅ |
| test_chat_non_streaming_structured_output (calendar) | ✅ |
| test_chat_non_streaming_structured_output (math) | ✅ |
| test_chat_non_streaming_tool_calling | ✅ |
| test_chat_non_streaming_tool_choice_none | ✅ |
| test_chat_non_streaming_tool_choice_required | ✅ |
| test_chat_streaming_basic (earth) | ✅ |
| test_chat_streaming_basic (saturn) | ✅ |
| test_chat_streaming_image | ✅ |
| test_chat_streaming_multi_turn_tool_calling (add_product_tool) | ✅ |
| test_chat_streaming_multi_turn_tool_calling (compare_monthly_expense_tool) | ✅ |
| test_chat_streaming_multi_turn_tool_calling (get_then_create_event_tool) | ✅ |
| test_chat_streaming_multi_turn_tool_calling (text_then_weather_tool) | ✅ |
| test_chat_streaming_multi_turn_tool_calling (weather_tool_then_text) | ✅ |
| test_chat_streaming_structured_output (calendar) | ✅ |
| test_chat_streaming_structured_output (math) | ✅ |
| test_chat_streaming_tool_calling | ✅ |
| test_chat_streaming_tool_choice_none | ✅ |
| test_chat_streaming_tool_choice_required | ✅ |
## Together

*Tests run on: 2025-04-14 18:08:14*
*Tests run on: 2025-04-17 12:27:45*

```bash
# Run all tests for this provider:
pytest tests/verifications/openai_api/test_chat_completion.py --provider=together -v

# Example: Run only the 'earth' case of test_chat_non_streaming_basic:
pytest tests/verifications/openai_api/test_chat_completion.py --provider=together -k "test_chat_non_streaming_basic and earth"
# Example: Run only the 'stream=False' case of test_chat_multi_turn_multiple_images:
pytest tests/verifications/openai_api/test_chat_completion.py --provider=together -k "test_chat_multi_turn_multiple_images and stream=False"
```

@@ -45,11 +97,13 @@ pytest tests/verifications/openai_api/test_chat_completion.py --provider=togethe

| Test | Llama-3.3-70B-Instruct | Llama-4-Maverick-Instruct | Llama-4-Scout-Instruct |
| --- | --- | --- | --- |
| test_chat_multi_turn_multiple_images (stream=False) | ⚪ | ✅ | ✅ |
| test_chat_multi_turn_multiple_images (stream=True) | ⚪ | ❌ | ❌ |
| test_chat_non_streaming_basic (earth) | ✅ | ✅ | ✅ |
| test_chat_non_streaming_basic (saturn) | ✅ | ✅ | ✅ |
| test_chat_non_streaming_image | ⚪ | ✅ | ✅ |
| test_chat_non_streaming_multi_turn_tool_calling (add_product_tool) | ✅ | ✅ | ✅ |
| test_chat_non_streaming_multi_turn_tool_calling (compare_monthly_expense_tool) | ❌ | ✅ | ✅ |
| test_chat_non_streaming_multi_turn_tool_calling (compare_monthly_expense_tool) | ✅ | ✅ | ✅ |
| test_chat_non_streaming_multi_turn_tool_calling (get_then_create_event_tool) | ✅ | ❌ | ✅ |
| test_chat_non_streaming_multi_turn_tool_calling (text_then_weather_tool) | ❌ | ❌ | ❌ |
| test_chat_non_streaming_multi_turn_tool_calling (weather_tool_then_text) | ✅ | ✅ | ✅ |
@@ -74,14 +128,14 @@ pytest tests/verifications/openai_api/test_chat_completion.py --provider=togethe

## Fireworks

*Tests run on: 2025-04-14 18:04:06*
*Tests run on: 2025-04-17 12:29:53*

```bash
# Run all tests for this provider:
pytest tests/verifications/openai_api/test_chat_completion.py --provider=fireworks -v

# Example: Run only the 'earth' case of test_chat_non_streaming_basic:
pytest tests/verifications/openai_api/test_chat_completion.py --provider=fireworks -k "test_chat_non_streaming_basic and earth"
# Example: Run only the 'stream=False' case of test_chat_multi_turn_multiple_images:
pytest tests/verifications/openai_api/test_chat_completion.py --provider=fireworks -k "test_chat_multi_turn_multiple_images and stream=False"
```

@@ -96,6 +150,8 @@ pytest tests/verifications/openai_api/test_chat_completion.py --provider=firewor

| Test | Llama-3.3-70B-Instruct | Llama-4-Maverick-Instruct | Llama-4-Scout-Instruct |
| --- | --- | --- | --- |
| test_chat_multi_turn_multiple_images (stream=False) | ⚪ | ✅ | ✅ |
| test_chat_multi_turn_multiple_images (stream=True) | ⚪ | ✅ | ✅ |
| test_chat_non_streaming_basic (earth) | ✅ | ✅ | ✅ |
| test_chat_non_streaming_basic (saturn) | ✅ | ✅ | ✅ |
| test_chat_non_streaming_image | ⚪ | ✅ | ✅ |
@@ -125,14 +181,14 @@ pytest tests/verifications/openai_api/test_chat_completion.py --provider=firewor

## Openai

*Tests run on: 2025-04-14 18:09:51*
*Tests run on: 2025-04-17 12:34:08*

```bash
# Run all tests for this provider:
pytest tests/verifications/openai_api/test_chat_completion.py --provider=openai -v

# Example: Run only the 'earth' case of test_chat_non_streaming_basic:
pytest tests/verifications/openai_api/test_chat_completion.py --provider=openai -k "test_chat_non_streaming_basic and earth"
# Example: Run only the 'stream=False' case of test_chat_multi_turn_multiple_images:
pytest tests/verifications/openai_api/test_chat_completion.py --provider=openai -k "test_chat_multi_turn_multiple_images and stream=False"
```

@@ -146,6 +202,8 @@ pytest tests/verifications/openai_api/test_chat_completion.py --provider=openai

| Test | gpt-4o | gpt-4o-mini |
| --- | --- | --- |
| test_chat_multi_turn_multiple_images (stream=False) | ✅ | ✅ |
| test_chat_multi_turn_multiple_images (stream=True) | ✅ | ✅ |
| test_chat_non_streaming_basic (earth) | ✅ | ✅ |
| test_chat_non_streaming_basic (saturn) | ✅ | ✅ |
| test_chat_non_streaming_image | ✅ | ✅ |
@@ -8,3 +8,4 @@ test_exclusions:
  llama-3.3-70b:
    - test_chat_non_streaming_image
    - test_chat_streaming_image
    - test_chat_multi_turn_multiple_images

@@ -12,3 +12,4 @@ test_exclusions:
  fireworks/llama-v3p3-70b-instruct:
    - test_chat_non_streaming_image
    - test_chat_streaming_image
    - test_chat_multi_turn_multiple_images

@@ -12,3 +12,4 @@ test_exclusions:
  accounts/fireworks/models/llama-v3p3-70b-instruct:
    - test_chat_non_streaming_image
    - test_chat_streaming_image
    - test_chat_multi_turn_multiple_images

@@ -12,3 +12,4 @@ test_exclusions:
  groq/llama-3.3-70b-versatile:
    - test_chat_non_streaming_image
    - test_chat_streaming_image
    - test_chat_multi_turn_multiple_images

@@ -12,3 +12,4 @@ test_exclusions:
  llama-3.3-70b-versatile:
    - test_chat_non_streaming_image
    - test_chat_streaming_image
    - test_chat_multi_turn_multiple_images
8
tests/verifications/conf/meta_reference.yaml
Normal file

@@ -0,0 +1,8 @@
# LLAMA_STACK_PORT=5002 llama stack run meta-reference-gpu --env INFERENCE_MODEL=meta-llama/Llama-4-Scout-17B-16E-Instruct --env INFERENCE_CHECKPOINT_DIR=<path_to_ckpt>
base_url: http://localhost:5002/v1/openai/v1
api_key_var: foo
models:
  - meta-llama/Llama-4-Scout-17B-16E-Instruct
model_display_names:
  meta-llama/Llama-4-Scout-17B-16E-Instruct: Llama-4-Scout-Instruct
test_exclusions: {}
@@ -12,3 +12,4 @@ test_exclusions:
  together/meta-llama/Llama-3.3-70B-Instruct-Turbo:
    - test_chat_non_streaming_image
    - test_chat_streaming_image
    - test_chat_multi_turn_multiple_images

@@ -12,3 +12,4 @@ test_exclusions:
  meta-llama/Llama-3.3-70B-Instruct-Turbo:
    - test_chat_non_streaming_image
    - test_chat_streaming_image
    - test_chat_multi_turn_multiple_images
@@ -3,14 +3,6 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

# /// script
# requires-python = ">=3.10"
# dependencies = [
#   "pytest-json-report",
#   "pyyaml",
# ]
# ///
"""
Test Report Generator

@@ -67,16 +59,11 @@ RESULTS_DIR.mkdir(exist_ok=True)
# Maximum number of test result files to keep per provider
MAX_RESULTS_PER_PROVIDER = 1

PROVIDER_ORDER = [
DEFAULT_PROVIDERS = [
    "meta_reference",
    "together",
    "fireworks",
    "groq",
    "cerebras",
    "openai",
    "together-llama-stack",
    "fireworks-llama-stack",
    "groq-llama-stack",
    "openai-llama-stack",
]

VERIFICATION_CONFIG = _load_all_verification_configs()

@@ -142,6 +129,14 @@ def run_tests(provider, keyword=None):
    return None


def run_multiple_tests(providers_to_run: list[str], keyword: str | None):
    """Runs tests for a list of providers."""
    print(f"Running tests for providers: {', '.join(providers_to_run)}")
    for provider in providers_to_run:
        run_tests(provider.strip(), keyword=keyword)
    print("Finished running tests.")


def parse_results(
    result_file,
) -> Tuple[DefaultDict[str, DefaultDict[str, Dict[str, bool]]], DefaultDict[str, Set[str]], Set[str], str]:

@@ -250,20 +245,6 @@
    return parsed_results, providers_in_file, tests_in_file, run_timestamp_str


def get_all_result_files_by_provider():
    """Get all test result files, keyed by provider."""
    provider_results = {}

    result_files = list(RESULTS_DIR.glob("*.json"))

    for file in result_files:
        provider = file.stem
        if provider:
            provider_results[provider] = file

    return provider_results


def generate_report(
    results_dict: Dict[str, Any],
    providers: Dict[str, Set[str]],

@@ -276,6 +257,7 @@ def generate_report(
    Args:
        results_dict: Aggregated results [provider][model][test_name] -> status.
        providers: Dict of all providers and their models {provider: {models}}.
            The order of keys in this dict determines the report order.
        all_tests: Set of all test names found.
        provider_timestamps: Dict of provider to timestamp when tests were run
        output_file: Optional path to save the report.

@@ -353,22 +335,17 @@
                passed_tests += 1
        provider_totals[provider] = (provider_passed, provider_total)

    # Add summary table (use passed-in providers dict)
    # Add summary table (use the order from the providers dict keys)
    report.append("| Provider | Pass Rate | Tests Passed | Total Tests |")
    report.append("| --- | --- | --- | --- |")
    for provider in [p for p in PROVIDER_ORDER if p in providers]:  # Check against keys of passed-in dict
        passed, total = provider_totals.get(provider, (0, 0))
        pass_rate = f"{(passed / total * 100):.1f}%" if total > 0 else "N/A"
        report.append(f"| {provider.capitalize()} | {pass_rate} | {passed} | {total} |")
    for provider in [p for p in providers if p not in PROVIDER_ORDER]:  # Check against keys of passed-in dict
    # Iterate through providers in the order they appear in the input dict
    for provider in providers_sorted.keys():
        passed, total = provider_totals.get(provider, (0, 0))
        pass_rate = f"{(passed / total * 100):.1f}%" if total > 0 else "N/A"
        report.append(f"| {provider.capitalize()} | {pass_rate} | {passed} | {total} |")
    report.append("\n")

    for provider in sorted(
        providers_sorted.keys(), key=lambda p: (PROVIDER_ORDER.index(p) if p in PROVIDER_ORDER else float("inf"), p)
    ):
    for provider in providers_sorted.keys():
        provider_models = providers_sorted[provider]  # Use sorted models
        if not provider_models:
            continue

@@ -461,60 +438,62 @@ def main():
        "--providers",
        type=str,
        nargs="+",
        help="Specify providers to test (comma-separated or space-separated, default: all)",
        help="Specify providers to include/test (comma-separated or space-separated, default: uses DEFAULT_PROVIDERS)",
    )
    parser.add_argument("--output", type=str, help="Output file location (default: tests/verifications/REPORT.md)")
    parser.add_argument("--k", type=str, help="Keyword expression to filter tests (passed to pytest -k)")
    args = parser.parse_args()

    all_results = {}
    # Initialize collections to aggregate results in main
    aggregated_providers = defaultdict(set)
    final_providers_order = {}  # Dictionary to store results, preserving processing order
    aggregated_tests = set()
    provider_timestamps = {}

    if args.run_tests:
        # Get list of available providers from command line or use detected providers
        if args.providers:
            # Handle both comma-separated and space-separated lists
            test_providers = []
            for provider_arg in args.providers:
                # Split by comma if commas are present
                if "," in provider_arg:
                    test_providers.extend(provider_arg.split(","))
                else:
                    test_providers.append(provider_arg)
        else:
            # Default providers to test
            test_providers = PROVIDER_ORDER

        for provider in test_providers:
            provider = provider.strip()  # Remove any whitespace
            result_file = run_tests(provider, keyword=args.k)
            if result_file:
                # Parse and aggregate results
                parsed_results, providers_in_file, tests_in_file, run_timestamp = parse_results(result_file)
                all_results.update(parsed_results)
                for prov, models in providers_in_file.items():
                    aggregated_providers[prov].update(models)
                    if run_timestamp:
                        provider_timestamps[prov] = run_timestamp
                aggregated_tests.update(tests_in_file)
    # 1. Determine the desired list and order of providers
    if args.providers:
        desired_providers = []
        for provider_arg in args.providers:
            desired_providers.extend([p.strip() for p in provider_arg.split(",")])
    else:
        # Use existing results
        provider_result_files = get_all_result_files_by_provider()
        desired_providers = DEFAULT_PROVIDERS  # Use default order/list

        for result_file in provider_result_files.values():
            # Parse and aggregate results
            parsed_results, providers_in_file, tests_in_file, run_timestamp = parse_results(result_file)
            all_results.update(parsed_results)
            for prov, models in providers_in_file.items():
                aggregated_providers[prov].update(models)
                if run_timestamp:
                    provider_timestamps[prov] = run_timestamp
            aggregated_tests.update(tests_in_file)
    # 2. Run tests if requested (using the desired provider list)
    if args.run_tests:
        run_multiple_tests(desired_providers, args.k)

    generate_report(all_results, aggregated_providers, aggregated_tests, provider_timestamps, args.output)
    for provider in desired_providers:
        # Construct the expected result file path directly
        result_file = RESULTS_DIR / f"{provider}.json"

        if result_file.exists():  # Check if the specific file exists
            print(f"Loading results for {provider} from {result_file}")
            try:
                parsed_data = parse_results(result_file)
                parsed_results, providers_in_file, tests_in_file, run_timestamp = parsed_data
                all_results.update(parsed_results)
                aggregated_tests.update(tests_in_file)

                # Add models for this provider, ensuring it's added in the correct report order
                if provider in providers_in_file:
                    if provider not in final_providers_order:
                        final_providers_order[provider] = set()
                    final_providers_order[provider].update(providers_in_file[provider])
                    if run_timestamp != "Unknown":
                        provider_timestamps[provider] = run_timestamp
                else:
                    print(
                        f"Warning: Provider '{provider}' found in desired list but not within its result file data ({result_file})."
                    )

            except Exception as e:
                print(f"Error parsing results for provider {provider} from {result_file}: {e}")
        else:
            # Only print warning if we expected results (i.e., provider was in the desired list)
            print(f"Result file for desired provider '{provider}' not found at {result_file}. Skipping.")

    # 5. Generate the report using the filtered & ordered results
    print(f"Final Provider Order for Report: {list(final_providers_order.keys())}")
    generate_report(all_results, final_providers_order, aggregated_tests, provider_timestamps, args.output)


if __name__ == "__main__":
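(The inline `# /// script` dependency block is dropped above because `pytest-json-report` now comes from the project's dev dependencies — see the `uv.lock` change at the end of this diff. The generator is assumed to drive pytest through that plugin, writing one JSON result file per provider into `tests/verifications/test_results/`; a hypothetical equivalent invocation, with flags taken from the plugin rather than from this diff:)

```bash
pytest tests/verifications/openai_api/test_chat_completion.py \
    --provider=openai \
    --json-report --json-report-file=tests/verifications/test_results/openai.json
```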
BIN
tests/verifications/openai_api/fixtures/images/vision_test_1.jpg
Normal file
Binary file not shown.
After Width: | Height: | Size: 108 KiB
BIN
tests/verifications/openai_api/fixtures/images/vision_test_2.jpg
Normal file
Binary file not shown.
After Width: | Height: | Size: 148 KiB
BIN
tests/verifications/openai_api/fixtures/images/vision_test_3.jpg
Normal file
Binary file not shown.
After Width: | Height: | Size: 139 KiB
@@ -15,6 +15,52 @@ test_chat_basic:
          S?
        role: user
      output: Saturn
test_chat_input_validation:
  test_name: test_chat_input_validation
  test_params:
    case:
      - case_id: "messages_missing"
        input:
          messages: []
        output:
          error:
            status_code: 400
      - case_id: "messages_role_invalid"
        input:
          messages:
            - content: Which planet do humans live on?
              role: fake_role
        output:
          error:
            status_code: 400
      - case_id: "tool_choice_invalid"
        input:
          messages:
            - content: Which planet do humans live on?
              role: user
          tool_choice: invalid
        output:
          error:
            status_code: 400
      - case_id: "tool_choice_no_tools"
        input:
          messages:
            - content: Which planet do humans live on?
              role: user
          tool_choice: required
        output:
          error:
            status_code: 400
      - case_id: "tools_type_invalid"
        input:
          messages:
            - content: Which planet do humans live on?
              role: user
          tools:
            - type: invalid
        output:
          error:
            status_code: 400
test_chat_image:
  test_name: test_chat_image
  test_params:
@@ -4,19 +4,26 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import base64
import copy
import json
import re
from pathlib import Path
from typing import Any

import pytest
from openai import APIError
from pydantic import BaseModel

from tests.verifications.openai_api.fixtures.fixtures import _load_all_verification_configs
from tests.verifications.openai_api.fixtures.fixtures import (
    _load_all_verification_configs,
)
from tests.verifications.openai_api.fixtures.load import load_test_cases

chat_completion_test_cases = load_test_cases("chat_completion")

THIS_DIR = Path(__file__).parent


def case_id_generator(case):
    """Generate a test ID from the case's 'case_id' field, or use a default."""

@@ -69,6 +76,21 @@ def get_base_test_name(request):
    return request.node.originalname


@pytest.fixture
def multi_image_data():
    files = [
        THIS_DIR / "fixtures/images/vision_test_1.jpg",
        THIS_DIR / "fixtures/images/vision_test_2.jpg",
        THIS_DIR / "fixtures/images/vision_test_3.jpg",
    ]
    encoded_files = []
    for file in files:
        with open(file, "rb") as image_file:
            base64_data = base64.b64encode(image_file.read()).decode("utf-8")
            encoded_files.append(f"data:image/jpeg;base64,{base64_data}")
    return encoded_files


# --- Test Functions ---


@@ -115,6 +137,50 @@ def test_chat_streaming_basic(request, openai_client, model, provider, verificat
    assert case["output"].lower() in content.lower()


@pytest.mark.parametrize(
    "case",
    chat_completion_test_cases["test_chat_input_validation"]["test_params"]["case"],
    ids=case_id_generator,
)
def test_chat_non_streaming_error_handling(request, openai_client, model, provider, verification_config, case):
    test_name_base = get_base_test_name(request)
    if should_skip_test(verification_config, provider, model, test_name_base):
        pytest.skip(f"Skipping {test_name_base} for model {model} on provider {provider} based on config.")

    with pytest.raises(APIError) as e:
        openai_client.chat.completions.create(
            model=model,
            messages=case["input"]["messages"],
            stream=False,
            tool_choice=case["input"]["tool_choice"] if "tool_choice" in case["input"] else None,
            tools=case["input"]["tools"] if "tools" in case["input"] else None,
        )
    assert case["output"]["error"]["status_code"] == e.value.status_code


@pytest.mark.parametrize(
    "case",
    chat_completion_test_cases["test_chat_input_validation"]["test_params"]["case"],
    ids=case_id_generator,
)
def test_chat_streaming_error_handling(request, openai_client, model, provider, verification_config, case):
    test_name_base = get_base_test_name(request)
    if should_skip_test(verification_config, provider, model, test_name_base):
        pytest.skip(f"Skipping {test_name_base} for model {model} on provider {provider} based on config.")

    with pytest.raises(APIError) as e:
        response = openai_client.chat.completions.create(
            model=model,
            messages=case["input"]["messages"],
            stream=True,
            tool_choice=case["input"]["tool_choice"] if "tool_choice" in case["input"] else None,
            tools=case["input"]["tools"] if "tools" in case["input"] else None,
        )
        for _chunk in response:
            pass
    assert str(case["output"]["error"]["status_code"]) in e.value.message


@pytest.mark.parametrize(
    "case",
    chat_completion_test_cases["test_chat_image"]["test_params"]["case"],

@@ -272,7 +338,6 @@ def test_chat_non_streaming_tool_choice_required(request, openai_client, model,
        tool_choice="required",  # Force tool call
        stream=False,
    )
    print(response)

    assert response.choices[0].message.role == "assistant"
    assert len(response.choices[0].message.tool_calls) > 0, "Expected tool call when tool_choice='required'"

@@ -532,6 +597,86 @@ def test_chat_streaming_multi_turn_tool_calling(request, openai_client, model, p
    )


@pytest.mark.parametrize("stream", [False, True], ids=["stream=False", "stream=True"])
def test_chat_multi_turn_multiple_images(
    request, openai_client, model, provider, verification_config, multi_image_data, stream
):
    test_name_base = get_base_test_name(request)
    if should_skip_test(verification_config, provider, model, test_name_base):
        pytest.skip(f"Skipping {test_name_base} for model {model} on provider {provider} based on config.")

    messages_turn1 = [
        {
            "role": "user",
            "content": [
                {
                    "type": "image_url",
                    "image_url": {
                        "url": multi_image_data[0],
                    },
                },
                {
                    "type": "image_url",
                    "image_url": {
                        "url": multi_image_data[1],
                    },
                },
                {
                    "type": "text",
                    "text": "What furniture is in the first image that is not in the second image?",
                },
            ],
        },
    ]

    # First API call
    response1 = openai_client.chat.completions.create(
        model=model,
        messages=messages_turn1,
        stream=stream,
    )
    if stream:
        message_content1 = ""
        for chunk in response1:
            message_content1 += chunk.choices[0].delta.content or ""
    else:
        message_content1 = response1.choices[0].message.content
    assert len(message_content1) > 0
    assert any(expected in message_content1.lower().strip() for expected in {"chair", "table"}), message_content1

    # Prepare messages for the second turn
    messages_turn2 = messages_turn1 + [
        {"role": "assistant", "content": message_content1},
        {
            "role": "user",
            "content": [
                {
                    "type": "image_url",
                    "image_url": {
                        "url": multi_image_data[2],
                    },
                },
                {"type": "text", "text": "What is in this image that is also in the first image?"},
            ],
        },
    ]

    # Second API call
    response2 = openai_client.chat.completions.create(
        model=model,
        messages=messages_turn2,
        stream=stream,
    )
    if stream:
        message_content2 = ""
        for chunk in response2:
            message_content2 += chunk.choices[0].delta.content or ""
    else:
        message_content2 = response2.choices[0].message.content
    assert len(message_content2) > 0
    assert any(expected in message_content2.lower().strip() for expected in {"bed"}), message_content2


# --- Helper functions (structured output validation) ---
File diff suppressed because one or more lines are too long
1097
tests/verifications/test_results/meta_reference.json
Normal file
File diff suppressed because it is too large
File diff suppressed because it is too large
File diff suppressed because one or more lines are too long
19
uv.lock
generated
@@ -1,4 +1,5 @@
version = 1
revision = 1
requires-python = ">=3.10"
resolution-markers = [
    "(python_full_version < '3.11' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version < '3.11' and sys_platform != 'darwin' and sys_platform != 'linux')",

@@ -1410,6 +1411,7 @@ dev = [
    { name = "pytest-asyncio" },
    { name = "pytest-cov" },
    { name = "pytest-html" },
    { name = "pytest-json-report" },
    { name = "ruamel-yaml" },
    { name = "ruff" },
    { name = "types-requests" },

@@ -1456,6 +1458,7 @@ unit = [
    { name = "aiosqlite" },
    { name = "chardet" },
    { name = "openai" },
    { name = "opentelemetry-exporter-otlp-proto-http" },
    { name = "pypdf" },
    { name = "qdrant-client" },
    { name = "sqlite-vec" },

@@ -1489,6 +1492,7 @@ requires-dist = [
    { name = "openai", marker = "extra == 'test'" },
    { name = "openai", marker = "extra == 'unit'" },
    { name = "opentelemetry-exporter-otlp-proto-http", marker = "extra == 'test'" },
    { name = "opentelemetry-exporter-otlp-proto-http", marker = "extra == 'unit'" },
    { name = "opentelemetry-sdk", marker = "extra == 'test'" },
    { name = "pandas", marker = "extra == 'ui'" },
    { name = "pillow" },

@@ -1502,6 +1506,7 @@ requires-dist = [
    { name = "pytest-asyncio", marker = "extra == 'dev'" },
    { name = "pytest-cov", marker = "extra == 'dev'" },
    { name = "pytest-html", marker = "extra == 'dev'" },
    { name = "pytest-json-report", marker = "extra == 'dev'" },
    { name = "python-dotenv" },
    { name = "qdrant-client", marker = "extra == 'unit'" },
    { name = "requests" },

@@ -1531,6 +1536,7 @@ requires-dist = [
    { name = "types-setuptools", marker = "extra == 'dev'" },
    { name = "uvicorn", marker = "extra == 'dev'" },
]
provides-extras = ["dev", "unit", "test", "docs", "codegen", "ui"]

[[package]]
name = "llama-stack-client"

@@ -2740,6 +2746,19 @@ wheels = [
    { url = "https://files.pythonhosted.org/packages/c8/c7/c160021cbecd956cc1a6f79e5fe155f7868b2e5b848f1320dad0b3e3122f/pytest_html-4.1.1-py3-none-any.whl", hash = "sha256:c8152cea03bd4e9bee6d525573b67bbc6622967b72b9628dda0ea3e2a0b5dd71", size = 23491 },
]

[[package]]
name = "pytest-json-report"
version = "1.5.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
    { name = "pytest" },
    { name = "pytest-metadata" },
]
sdist = { url = "https://files.pythonhosted.org/packages/4f/d3/765dae9712fcd68d820338908c1337e077d5fdadccd5cacf95b9b0bea278/pytest-json-report-1.5.0.tar.gz", hash = "sha256:2dde3c647851a19b5f3700729e8310a6e66efb2077d674f27ddea3d34dc615de", size = 21241 }
wheels = [
    { url = "https://files.pythonhosted.org/packages/81/35/d07400c715bf8a88aa0c1ee9c9eb6050ca7fe5b39981f0eea773feeb0681/pytest_json_report-1.5.0-py3-none-any.whl", hash = "sha256:9897b68c910b12a2e48dd849f9a284b2c79a732a8a9cb398452ddd23d3c8c325", size = 13222 },
]

[[package]]
name = "pytest-metadata"
version = "3.1.1"