Merge remote-tracking branch 'origin/main' into k8s_demo

Kai Wu 2025-07-29 09:00:45 -07:00
commit 95d25ddfe2
101 changed files with 3309 additions and 5108 deletions


@@ -117,17 +117,13 @@ jobs:
           EXCLUDE_TESTS="builtin_tool or safety_with_image or code_interpreter or test_rag"
           if [ "${{ matrix.provider }}" == "ollama" ]; then
-            export ENABLE_OLLAMA="ollama"
             export OLLAMA_URL="http://0.0.0.0:11434"
-            export OLLAMA_INFERENCE_MODEL="llama3.2:3b-instruct-fp16"
-            export TEXT_MODEL=ollama/$OLLAMA_INFERENCE_MODEL
-            export SAFETY_MODEL="llama-guard3:1b"
-            EXTRA_PARAMS="--safety-shield=$SAFETY_MODEL"
+            export TEXT_MODEL=ollama/llama3.2:3b-instruct-fp16
+            export SAFETY_MODEL="ollama/llama-guard3:1b"
+            EXTRA_PARAMS="--safety-shield=llama-guard"
           else
-            export ENABLE_VLLM="vllm"
             export VLLM_URL="http://localhost:8000/v1"
-            export VLLM_INFERENCE_MODEL="meta-llama/Llama-3.2-1B-Instruct"
-            export TEXT_MODEL=vllm/$VLLM_INFERENCE_MODEL
+            export TEXT_MODEL=vllm/meta-llama/Llama-3.2-1B-Instruct
             # TODO: remove the not(test_inference_store_tool_calls) once we can get the tool called consistently
             EXTRA_PARAMS=
             EXCLUDE_TESTS="${EXCLUDE_TESTS} or test_inference_store_tool_calls"


@@ -14,10 +14,18 @@ concurrency:
 jobs:
   pre-commit:
     runs-on: ubuntu-latest
+    permissions:
+      contents: write
+      pull-requests: write

     steps:
       - name: Checkout code
         uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
+        with:
+          # For dependabot PRs, we need to checkout with a token that can push changes
+          token: ${{ github.actor == 'dependabot[bot]' && secrets.GITHUB_TOKEN || github.token }}
+          # Fetch full history for dependabot PRs to allow commits
+          fetch-depth: ${{ github.actor == 'dependabot[bot]' && 0 || 1 }}

       - name: Set up Python
         uses: actions/setup-python@a26af69be951a213d495a4c3e4e4022e16d87065 # v5.6.0
@@ -29,15 +37,45 @@ jobs:
           .pre-commit-config.yaml

       - uses: pre-commit/action@2c7b3805fd2a0fd8c1884dcaebf91fc102a13ecd # v3.0.1
+        continue-on-error: true
         env:
           SKIP: no-commit-to-branch
           RUFF_OUTPUT_FORMAT: github

+      - name: Debug
+        run: |
+          echo "github.ref: ${{ github.ref }}"
+          echo "github.actor: ${{ github.actor }}"
+
+      - name: Commit changes for dependabot PRs
+        if: github.actor == 'dependabot[bot]'
+        run: |
+          if ! git diff --exit-code || [ -n "$(git ls-files --others --exclude-standard)" ]; then
+            git config --local user.email "github-actions[bot]@users.noreply.github.com"
+            git config --local user.name "github-actions[bot]"
+            # Ensure we're on the correct branch
+            git checkout -B ${{ github.head_ref }}
+            git add -A
+            git commit -m "Apply pre-commit fixes"
+            # Pull latest changes from the PR branch and rebase our commit on top
+            git pull --rebase origin ${{ github.head_ref }}
+            # Push to the PR branch
+            git push origin ${{ github.head_ref }}
+            echo "Pre-commit fixes committed and pushed"
+          else
+            echo "No changes to commit"
+          fi
+
       - name: Verify if there are any diff files after pre-commit
+        if: github.actor != 'dependabot[bot]'
         run: |
           git diff --exit-code || (echo "There are uncommitted changes, run pre-commit locally and commit again" && exit 1)

       - name: Verify if there are any new files after pre-commit
+        if: github.actor != 'dependabot[bot]'
         run: |
           unstaged_files=$(git ls-files --others --exclude-standard)
           if [ -n "$unstaged_files" ]; then


@@ -13,7 +13,8 @@ on:
       - 'uv.lock'
       - 'pyproject.toml'
       - 'requirements.txt'
-      - '.github/workflows/test-external-providers-module.yml' # This workflow
+      - 'tests/external/*'
+      - '.github/workflows/test-external-provider-module.yml' # This workflow

 jobs:
   test-external-providers-from-module:
@@ -52,6 +53,7 @@ jobs:
         if: ${{ matrix.image-type }} == 'venv'
         env:
           INFERENCE_MODEL: "llama3.2:3b-instruct-fp16"
+          LLAMA_STACK_LOG_FILE: "server.log"
         run: |
           # Use the virtual environment created by the build step (name comes from build config)
           source ramalama-stack-test/bin/activate
@@ -72,3 +74,12 @@ jobs:
             echo "Provider failed to load"
             cat server.log
             exit 1
+
+      - name: Upload all logs to artifacts
+        if: ${{ always() }}
+        uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # v4.6.2
+        with:
+          name: logs-${{ github.run_id }}-${{ github.run_attempt }}-external-provider-module-test
+          path: |
+            *.log
+          retention-days: 1


@@ -13,6 +13,7 @@ on:
       - 'uv.lock'
       - 'pyproject.toml'
       - 'requirements.txt'
+      - 'tests/external/*'
       - '.github/workflows/test-external.yml' # This workflow

 jobs:
@@ -52,6 +53,7 @@ jobs:
         if: ${{ matrix.image-type }} == 'venv'
         env:
           INFERENCE_MODEL: "meta-llama/Llama-3.2-3B-Instruct"
+          LLAMA_STACK_LOG_FILE: "server.log"
         run: |
           # Use the virtual environment created by the build step (name comes from build config)
           source ci-test/bin/activate
@@ -75,3 +77,12 @@ jobs:
       - name: Test external API
         run: |
           curl -sSf http://localhost:8321/v1/weather/locations
+
+      - name: Upload all logs to artifacts
+        if: ${{ always() }}
+        uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # v4.6.2
+        with:
+          name: logs-${{ github.run_id }}-${{ github.run_attempt }}-external-test
+          path: |
+            *.log
+          retention-days: 1


@@ -35,6 +35,8 @@ jobs:
       - name: Install dependencies
         uses: ./.github/actions/setup-runner
+        with:
+          python-version: ${{ matrix.python }}

       - name: Run unit tests
         run: |


@@ -19,7 +19,6 @@ repos:
       - id: check-yaml
         args: ["--unsafe"]
       - id: detect-private-key
-      - id: requirements-txt-fixer
       - id: mixed-line-ending
         args: [--fix=lf] # Forces to replace line ending by LF (line feed)
       - id: check-executables-have-shebangs
@@ -56,14 +55,6 @@ repos:
     rev: 0.7.20
     hooks:
       - id: uv-lock
-      - id: uv-export
-        args: [
-          "--frozen",
-          "--no-hashes",
-          "--no-emit-project",
-          "--no-default-groups",
-          "--output-file=requirements.txt"
-        ]

   - repo: https://github.com/pre-commit/mirrors-mypy
     rev: v1.16.1


@ -9770,7 +9770,7 @@
{ {
"type": "array", "type": "array",
"items": { "items": {
"$ref": "#/components/schemas/OpenAIChatCompletionContentPartParam" "$ref": "#/components/schemas/OpenAIChatCompletionContentPartTextParam"
} }
} }
], ],
@ -9821,13 +9821,17 @@
}, },
{ {
"$ref": "#/components/schemas/OpenAIChatCompletionContentPartImageParam" "$ref": "#/components/schemas/OpenAIChatCompletionContentPartImageParam"
},
{
"$ref": "#/components/schemas/OpenAIFile"
} }
], ],
"discriminator": { "discriminator": {
"propertyName": "type", "propertyName": "type",
"mapping": { "mapping": {
"text": "#/components/schemas/OpenAIChatCompletionContentPartTextParam", "text": "#/components/schemas/OpenAIChatCompletionContentPartTextParam",
"image_url": "#/components/schemas/OpenAIChatCompletionContentPartImageParam" "image_url": "#/components/schemas/OpenAIChatCompletionContentPartImageParam",
"file": "#/components/schemas/OpenAIFile"
} }
} }
}, },
@ -9955,7 +9959,7 @@
{ {
"type": "array", "type": "array",
"items": { "items": {
"$ref": "#/components/schemas/OpenAIChatCompletionContentPartParam" "$ref": "#/components/schemas/OpenAIChatCompletionContentPartTextParam"
} }
} }
], ],
@ -9974,6 +9978,41 @@
"title": "OpenAIDeveloperMessageParam", "title": "OpenAIDeveloperMessageParam",
"description": "A message from the developer in an OpenAI-compatible chat completion request." "description": "A message from the developer in an OpenAI-compatible chat completion request."
}, },
"OpenAIFile": {
"type": "object",
"properties": {
"type": {
"type": "string",
"const": "file",
"default": "file"
},
"file": {
"$ref": "#/components/schemas/OpenAIFileFile"
}
},
"additionalProperties": false,
"required": [
"type",
"file"
],
"title": "OpenAIFile"
},
"OpenAIFileFile": {
"type": "object",
"properties": {
"file_data": {
"type": "string"
},
"file_id": {
"type": "string"
},
"filename": {
"type": "string"
}
},
"additionalProperties": false,
"title": "OpenAIFileFile"
},
"OpenAIImageURL": { "OpenAIImageURL": {
"type": "object", "type": "object",
"properties": { "properties": {
@ -10036,7 +10075,7 @@
{ {
"type": "array", "type": "array",
"items": { "items": {
"$ref": "#/components/schemas/OpenAIChatCompletionContentPartParam" "$ref": "#/components/schemas/OpenAIChatCompletionContentPartTextParam"
} }
} }
], ],
@ -10107,7 +10146,7 @@
{ {
"type": "array", "type": "array",
"items": { "items": {
"$ref": "#/components/schemas/OpenAIChatCompletionContentPartParam" "$ref": "#/components/schemas/OpenAIChatCompletionContentPartTextParam"
} }
} }
], ],


@ -6895,7 +6895,7 @@ components:
- type: string - type: string
- type: array - type: array
items: items:
$ref: '#/components/schemas/OpenAIChatCompletionContentPartParam' $ref: '#/components/schemas/OpenAIChatCompletionContentPartTextParam'
description: The content of the model's response description: The content of the model's response
name: name:
type: string type: string
@ -6934,11 +6934,13 @@ components:
oneOf: oneOf:
- $ref: '#/components/schemas/OpenAIChatCompletionContentPartTextParam' - $ref: '#/components/schemas/OpenAIChatCompletionContentPartTextParam'
- $ref: '#/components/schemas/OpenAIChatCompletionContentPartImageParam' - $ref: '#/components/schemas/OpenAIChatCompletionContentPartImageParam'
- $ref: '#/components/schemas/OpenAIFile'
discriminator: discriminator:
propertyName: type propertyName: type
mapping: mapping:
text: '#/components/schemas/OpenAIChatCompletionContentPartTextParam' text: '#/components/schemas/OpenAIChatCompletionContentPartTextParam'
image_url: '#/components/schemas/OpenAIChatCompletionContentPartImageParam' image_url: '#/components/schemas/OpenAIChatCompletionContentPartImageParam'
file: '#/components/schemas/OpenAIFile'
OpenAIChatCompletionContentPartTextParam: OpenAIChatCompletionContentPartTextParam:
type: object type: object
properties: properties:
@ -7037,7 +7039,7 @@ components:
- type: string - type: string
- type: array - type: array
items: items:
$ref: '#/components/schemas/OpenAIChatCompletionContentPartParam' $ref: '#/components/schemas/OpenAIChatCompletionContentPartTextParam'
description: The content of the developer message description: The content of the developer message
name: name:
type: string type: string
@ -7050,6 +7052,31 @@ components:
title: OpenAIDeveloperMessageParam title: OpenAIDeveloperMessageParam
description: >- description: >-
A message from the developer in an OpenAI-compatible chat completion request. A message from the developer in an OpenAI-compatible chat completion request.
OpenAIFile:
type: object
properties:
type:
type: string
const: file
default: file
file:
$ref: '#/components/schemas/OpenAIFileFile'
additionalProperties: false
required:
- type
- file
title: OpenAIFile
OpenAIFileFile:
type: object
properties:
file_data:
type: string
file_id:
type: string
filename:
type: string
additionalProperties: false
title: OpenAIFileFile
OpenAIImageURL: OpenAIImageURL:
type: object type: object
properties: properties:
@ -7090,7 +7117,7 @@ components:
- type: string - type: string
- type: array - type: array
items: items:
$ref: '#/components/schemas/OpenAIChatCompletionContentPartParam' $ref: '#/components/schemas/OpenAIChatCompletionContentPartTextParam'
description: >- description: >-
The content of the "system prompt". If multiple system messages are provided, The content of the "system prompt". If multiple system messages are provided,
they are concatenated. The underlying Llama Stack code may also add other they are concatenated. The underlying Llama Stack code may also add other
@ -7148,7 +7175,7 @@ components:
- type: string - type: string
- type: array - type: array
items: items:
$ref: '#/components/schemas/OpenAIChatCompletionContentPartParam' $ref: '#/components/schemas/OpenAIChatCompletionContentPartTextParam'
description: The response content from the tool description: The response content from the tool
additionalProperties: false additionalProperties: false
required: required:


@@ -249,12 +249,6 @@
    ],
    "source": [
     "from llama_stack_client import Agent, AgentEventLogger, RAGDocument, LlamaStackClient\n",
-    "import os\n",
-    "\n",
-    "os.environ[\"ENABLE_OLLAMA\"] = \"ollama\"\n",
-    "os.environ[\"OLLAMA_INFERENCE_MODEL\"] = \"llama3.2:3b\"\n",
-    "os.environ[\"OLLAMA_EMBEDDING_MODEL\"] = \"all-minilm:l6-v2\"\n",
-    "os.environ[\"OLLAMA_EMBEDDING_DIMENSION\"] = \"384\"\n",
     "\n",
     "vector_db_id = \"my_demo_vector_db\"\n",
     "client = LlamaStackClient(base_url=\"http://0.0.0.0:8321\")\n",


@@ -13,7 +13,7 @@ llama stack build --template starter --image-type venv
 from llama_stack.distribution.library_client import LlamaStackAsLibraryClient

 client = LlamaStackAsLibraryClient(
-    "ollama",
+    "starter",
     # provider_data is optional, but if you need to pass in any provider specific data, you can do so here.
     provider_data={"tavily_search_api_key": os.environ["TAVILY_SEARCH_API_KEY"]},
 )


@@ -40,16 +40,16 @@ The following environment variables can be configured:

 The following models are available by default:

-- `meta/llama3-8b-instruct (aliases: meta-llama/Llama-3-8B-Instruct)`
-- `meta/llama3-70b-instruct (aliases: meta-llama/Llama-3-70B-Instruct)`
-- `meta/llama-3.1-8b-instruct (aliases: meta-llama/Llama-3.1-8B-Instruct)`
-- `meta/llama-3.1-70b-instruct (aliases: meta-llama/Llama-3.1-70B-Instruct)`
-- `meta/llama-3.1-405b-instruct (aliases: meta-llama/Llama-3.1-405B-Instruct-FP8)`
-- `meta/llama-3.2-1b-instruct (aliases: meta-llama/Llama-3.2-1B-Instruct)`
-- `meta/llama-3.2-3b-instruct (aliases: meta-llama/Llama-3.2-3B-Instruct)`
-- `meta/llama-3.2-11b-vision-instruct (aliases: meta-llama/Llama-3.2-11B-Vision-Instruct)`
-- `meta/llama-3.2-90b-vision-instruct (aliases: meta-llama/Llama-3.2-90B-Vision-Instruct)`
-- `meta/llama-3.3-70b-instruct (aliases: meta-llama/Llama-3.3-70B-Instruct)`
+- `meta/llama3-8b-instruct `
+- `meta/llama3-70b-instruct `
+- `meta/llama-3.1-8b-instruct `
+- `meta/llama-3.1-70b-instruct `
+- `meta/llama-3.1-405b-instruct `
+- `meta/llama-3.2-1b-instruct `
+- `meta/llama-3.2-3b-instruct `
+- `meta/llama-3.2-11b-vision-instruct `
+- `meta/llama-3.2-90b-vision-instruct `
+- `meta/llama-3.3-70b-instruct `
 - `nvidia/llama-3.2-nv-embedqa-1b-v2 `
 - `nvidia/nv-embedqa-e5-v5 `
 - `nvidia/nv-embedqa-mistral-7b-v2 `


@@ -158,7 +158,7 @@ export ENABLE_PGVECTOR=__disabled__
 The starter distribution uses several patterns for provider IDs:

 1. **Direct provider IDs**: `faiss`, `ollama`, `vllm`
-2. **Environment-based provider IDs**: `${env.ENABLE_SQLITE_VEC+sqlite-vec}`
+2. **Environment-based provider IDs**: `${env.ENABLE_SQLITE_VEC:+sqlite-vec}`
 3. **Model-based provider IDs**: `${env.OLLAMA_INFERENCE_MODEL:__disabled__}`

 When using the `+` pattern (like `${env.ENABLE_SQLITE_VEC+sqlite-vec}`), the provider is enabled by default and can be disabled by setting the environment variable to `__disabled__`.
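To make the two conditional patterns concrete, here is an illustrative sketch of how an environment-based (`:+`) and a model-based (`:__disabled__`) provider ID would resolve. The helper name and exact precedence are assumptions based only on the rules described above, not the actual Llama Stack substitution code.

```python
import os

def resolve_provider_id(var_name: str, enabled_value: str, style: str) -> str | None:
    """Sketch of the provider-ID patterns described above (not the real implementation).

    style=":+"            -> enabled by default; setting the env var to "__disabled__" turns it off
    style=":__disabled__" -> disabled unless the env var is set to a real value
    """
    value = os.environ.get(var_name)
    if style == ":+":
        # e.g. ${env.ENABLE_SQLITE_VEC:+sqlite-vec}
        return None if value == "__disabled__" else enabled_value
    # e.g. ${env.OLLAMA_INFERENCE_MODEL:__disabled__}
    return value if value and value != "__disabled__" else None

# With ENABLE_SQLITE_VEC unset, the sqlite-vec provider stays enabled;
# exporting ENABLE_SQLITE_VEC=__disabled__ removes it from the run config.
print(resolve_provider_id("ENABLE_SQLITE_VEC", "sqlite-vec", ":+"))
```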


@@ -59,7 +59,7 @@ Now let's build and run the Llama Stack config for Ollama.
 We use `starter` as template. By default all providers are disabled, this requires enable ollama by passing environment variables.

 ```bash
-ENABLE_OLLAMA=ollama OLLAMA_INFERENCE_MODEL="llama3.2:3b" llama stack build --template starter --image-type venv --run
+llama stack build --template starter --image-type venv --run
 ```
 :::
 :::{tab-item} Using `conda`
@@ -70,7 +70,7 @@ which defines the providers and their settings.
 Now let's build and run the Llama Stack config for Ollama.

 ```bash
-ENABLE_OLLAMA=ollama INFERENCE_MODEL="llama3.2:3b" llama stack build --template starter --image-type conda --run
+llama stack build --template starter --image-type conda --run
 ```
 :::
 :::{tab-item} Using a Container
@@ -80,8 +80,6 @@ component that works with different inference providers out of the box. For this
 configurations, please check out [this guide](../distributions/building_distro.md).
 First lets setup some environment variables and create a local directory to mount into the containers file system.
 ```bash
-export INFERENCE_MODEL="llama3.2:3b"
-export ENABLE_OLLAMA=ollama
 export LLAMA_STACK_PORT=8321
 mkdir -p ~/.llama
 ```
@@ -94,7 +92,6 @@ docker run -it \
   -v ~/.llama:/root/.llama \
   llamastack/distribution-starter \
   --port $LLAMA_STACK_PORT \
-  --env INFERENCE_MODEL=$INFERENCE_MODEL \
   --env OLLAMA_URL=http://host.docker.internal:11434
 ```
 Note to start the container with Podman, you can do the same but replace `docker` at the start of the command with
@@ -116,7 +113,6 @@ docker run -it \
   --network=host \
   llamastack/distribution-starter \
   --port $LLAMA_STACK_PORT \
-  --env INFERENCE_MODEL=$INFERENCE_MODEL \
   --env OLLAMA_URL=http://localhost:11434
 ```
 :::


@@ -19,7 +19,7 @@ ollama run llama3.2:3b --keepalive 60m
 #### Step 2: Run the Llama Stack server
 We will use `uv` to run the Llama Stack server.
 ```bash
-ENABLE_OLLAMA=ollama OLLAMA_INFERENCE_MODEL=llama3.2:3b uv run --with llama-stack llama stack build --template starter --image-type venv --run
+uv run --with llama-stack llama stack build --template starter --image-type venv --run
 ```
 #### Step 3: Run the demo
 Now open up a new terminal and copy the following script into a file named `demo_script.py`.


@@ -12,8 +12,7 @@ To enable external providers, you need to add `module` into your build yaml, all
 an example entry in your build.yaml should look like:

 ```
-- provider_id: ramalama
-  provider_type: remote::ramalama
+- provider_type: remote::ramalama
   module: ramalama_stack
 ```
@@ -255,8 +254,7 @@ distribution_spec:
   container_image: null
   providers:
     inference:
-    - provider_id: ramalama
-      provider_type: remote::ramalama
+    - provider_type: remote::ramalama
       module: ramalama_stack==0.3.0a0
 image_type: venv
 image_name: null


@@ -13,7 +13,7 @@ Anthropic inference provider for accessing Claude models and Anthropic's AI serv
 ## Sample Configuration

 ```yaml
-api_key: ${env.ANTHROPIC_API_KEY}
+api_key: ${env.ANTHROPIC_API_KEY:=}
 ```
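Many sample configurations in this commit switch from `${env.VAR}` to `${env.VAR:=default}`. Assuming the bash-like semantics implied by examples such as `${env.OPENAI_BASE_URL:=https://api.openai.com/v1}` later in this diff, the value after `:=` (often an empty string) is used whenever the variable is unset. The snippet below is only a sketch of that resolution rule with a hypothetical helper name, not the Llama Stack implementation.

```python
import os
import re

# Hypothetical helper: expands "${env.NAME:=default}" the way the sample configs imply
# (environment value if set, otherwise the default after ":=").
_ENV_PATTERN = re.compile(r"\$\{env\.([A-Z0-9_]+):=([^}]*)\}")

def expand_env_defaults(value: str) -> str:
    return _ENV_PATTERN.sub(lambda m: os.environ.get(m.group(1)) or m.group(2), value)

print(expand_env_defaults("api_key: ${env.ANTHROPIC_API_KEY:=}"))
print(expand_env_defaults("base_url: ${env.OPENAI_BASE_URL:=https://api.openai.com/v1}"))
```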


@@ -15,7 +15,7 @@ Cerebras inference provider for running models on Cerebras Cloud platform.
 ```yaml
 base_url: https://api.cerebras.ai
-api_key: ${env.CEREBRAS_API_KEY}
+api_key: ${env.CEREBRAS_API_KEY:=}
 ```


@@ -14,8 +14,8 @@ Databricks inference provider for running models on Databricks' unified analytic
 ## Sample Configuration

 ```yaml
-url: ${env.DATABRICKS_URL}
-api_token: ${env.DATABRICKS_API_TOKEN}
+url: ${env.DATABRICKS_URL:=}
+api_token: ${env.DATABRICKS_API_TOKEN:=}
 ```


@@ -16,7 +16,7 @@ Fireworks AI inference provider for Llama models and other AI models on the Fire
 ```yaml
 url: https://api.fireworks.ai/inference/v1
-api_key: ${env.FIREWORKS_API_KEY}
+api_key: ${env.FIREWORKS_API_KEY:=}
 ```


@@ -13,7 +13,7 @@ Google Gemini inference provider for accessing Gemini models and Google's AI ser
 ## Sample Configuration

 ```yaml
-api_key: ${env.GEMINI_API_KEY}
+api_key: ${env.GEMINI_API_KEY:=}
 ```


@@ -15,7 +15,7 @@ Groq inference provider for ultra-fast inference using Groq's LPU technology.
 ```yaml
 url: https://api.groq.com
-api_key: ${env.GROQ_API_KEY}
+api_key: ${env.GROQ_API_KEY:=}
 ```


@@ -9,11 +9,13 @@ OpenAI inference provider for accessing GPT models and other OpenAI services.
 | Field | Type | Required | Default | Description |
 |-------|------|----------|---------|-------------|
 | `api_key` | `str \| None` | No | | API key for OpenAI models |
+| `base_url` | `<class 'str'>` | No | https://api.openai.com/v1 | Base URL for OpenAI API |

 ## Sample Configuration

 ```yaml
-api_key: ${env.OPENAI_API_KEY}
+api_key: ${env.OPENAI_API_KEY:=}
+base_url: ${env.OPENAI_BASE_URL:=https://api.openai.com/v1}
 ```


@@ -15,7 +15,7 @@ SambaNova OpenAI-compatible provider for using SambaNova models with OpenAI API
 ```yaml
 openai_compat_api_base: https://api.sambanova.ai/v1
-api_key: ${env.SAMBANOVA_API_KEY}
+api_key: ${env.SAMBANOVA_API_KEY:=}
 ```


@@ -15,7 +15,7 @@ SambaNova inference provider for running models on SambaNova's dataflow architec
 ```yaml
 url: https://api.sambanova.ai/v1
-api_key: ${env.SAMBANOVA_API_KEY}
+api_key: ${env.SAMBANOVA_API_KEY:=}
 ```


@@ -13,7 +13,7 @@ Text Generation Inference (TGI) provider for HuggingFace model serving.
 ## Sample Configuration

 ```yaml
-url: ${env.TGI_URL}
+url: ${env.TGI_URL:=}
 ```


@@ -16,7 +16,7 @@ Together AI inference provider for open-source models and collaborative AI devel
 ```yaml
 url: https://api.together.xyz/v1
-api_key: ${env.TOGETHER_API_KEY}
+api_key: ${env.TOGETHER_API_KEY:=}
 ```


@@ -15,7 +15,7 @@ SambaNova's safety provider for content moderation and safety filtering.
 ```yaml
 url: https://api.sambanova.ai/v1
-api_key: ${env.SAMBANOVA_API_KEY}
+api_key: ${env.SAMBANOVA_API_KEY:=}
 ```


@@ -455,8 +455,21 @@ class OpenAIChatCompletionContentPartImageParam(BaseModel):
     image_url: OpenAIImageURL


+@json_schema_type
+class OpenAIFileFile(BaseModel):
+    file_data: str | None = None
+    file_id: str | None = None
+    filename: str | None = None
+
+
+@json_schema_type
+class OpenAIFile(BaseModel):
+    type: Literal["file"] = "file"
+    file: OpenAIFileFile
+
+
 OpenAIChatCompletionContentPartParam = Annotated[
-    OpenAIChatCompletionContentPartTextParam | OpenAIChatCompletionContentPartImageParam,
+    OpenAIChatCompletionContentPartTextParam | OpenAIChatCompletionContentPartImageParam | OpenAIFile,
     Field(discriminator="type"),
 ]
 register_schema(OpenAIChatCompletionContentPartParam, name="OpenAIChatCompletionContentPartParam")
@@ -464,6 +477,8 @@ register_schema(OpenAIChatCompletionContentPartParam, name="OpenAIChatCompletion

 OpenAIChatCompletionMessageContent = str | list[OpenAIChatCompletionContentPartParam]

+OpenAIChatCompletionTextOnlyMessageContent = str | list[OpenAIChatCompletionContentPartTextParam]
+

 @json_schema_type
 class OpenAIUserMessageParam(BaseModel):
@@ -489,7 +504,7 @@ class OpenAISystemMessageParam(BaseModel):
     """

     role: Literal["system"] = "system"
-    content: OpenAIChatCompletionMessageContent
+    content: OpenAIChatCompletionTextOnlyMessageContent
     name: str | None = None
@@ -518,7 +533,7 @@ class OpenAIAssistantMessageParam(BaseModel):
     """

     role: Literal["assistant"] = "assistant"
-    content: OpenAIChatCompletionMessageContent | None = None
+    content: OpenAIChatCompletionTextOnlyMessageContent | None = None
     name: str | None = None
     tool_calls: list[OpenAIChatCompletionToolCall] | None = None
@@ -534,7 +549,7 @@ class OpenAIToolMessageParam(BaseModel):

     role: Literal["tool"] = "tool"
     tool_call_id: str
-    content: OpenAIChatCompletionMessageContent
+    content: OpenAIChatCompletionTextOnlyMessageContent


 @json_schema_type
@@ -547,7 +562,7 @@ class OpenAIDeveloperMessageParam(BaseModel):
     """

     role: Literal["developer"] = "developer"
-    content: OpenAIChatCompletionMessageContent
+    content: OpenAIChatCompletionTextOnlyMessageContent
     name: str | None = None
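A short usage sketch of the new `file` content part, built from the fields defined above (`file_data`, `file_id`, `filename`). The import path and the example file ID are assumptions for illustration; after this change only user messages keep the full content-part union, while system, assistant, tool, and developer messages are restricted to text-only content.

```python
# Sketch only: module path assumed; field names come from the hunk above.
from llama_stack.apis.inference import (  # assumed import path
    OpenAIChatCompletionContentPartTextParam,
    OpenAIFile,
    OpenAIFileFile,
    OpenAIUserMessageParam,
)

# A user message mixing a text part with an uploaded file reference (hypothetical file id).
message = OpenAIUserMessageParam(
    role="user",
    content=[
        OpenAIChatCompletionContentPartTextParam(type="text", text="Summarize the attached report."),
        OpenAIFile(type="file", file=OpenAIFileFile(file_id="file-abc123", filename="report.pdf")),
    ],
)
```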


@ -31,6 +31,7 @@ from llama_stack.distribution.build import (
from llama_stack.distribution.configure import parse_and_maybe_upgrade_config from llama_stack.distribution.configure import parse_and_maybe_upgrade_config
from llama_stack.distribution.datatypes import ( from llama_stack.distribution.datatypes import (
BuildConfig, BuildConfig,
BuildProvider,
DistributionSpec, DistributionSpec,
Provider, Provider,
StackRunConfig, StackRunConfig,
@ -94,7 +95,7 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
) )
sys.exit(1) sys.exit(1)
elif args.providers: elif args.providers:
provider_list: dict[str, list[Provider]] = dict() provider_list: dict[str, list[BuildProvider]] = dict()
for api_provider in args.providers.split(","): for api_provider in args.providers.split(","):
if "=" not in api_provider: if "=" not in api_provider:
cprint( cprint(
@ -113,10 +114,8 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
) )
sys.exit(1) sys.exit(1)
if provider_type in providers_for_api: if provider_type in providers_for_api:
provider = Provider( provider = BuildProvider(
provider_type=provider_type, provider_type=provider_type,
provider_id=provider_type.split("::")[1],
config={},
module=None, module=None,
) )
provider_list.setdefault(api, []).append(provider) provider_list.setdefault(api, []).append(provider)
@ -189,7 +188,7 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
cprint("Tip: use <TAB> to see options for the providers.\n", color="green", file=sys.stderr) cprint("Tip: use <TAB> to see options for the providers.\n", color="green", file=sys.stderr)
providers: dict[str, list[Provider]] = dict() providers: dict[str, list[BuildProvider]] = dict()
for api, providers_for_api in get_provider_registry().items(): for api, providers_for_api in get_provider_registry().items():
available_providers = [x for x in providers_for_api.keys() if x not in ("remote", "remote::sample")] available_providers = [x for x in providers_for_api.keys() if x not in ("remote", "remote::sample")]
if not available_providers: if not available_providers:
@ -204,7 +203,10 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
), ),
) )
providers[api.value] = api_provider string_providers = api_provider.split(" ")
for provider in string_providers:
providers.setdefault(api.value, []).append(BuildProvider(provider_type=provider))
description = prompt( description = prompt(
"\n > (Optional) Enter a short description for your Llama Stack: ", "\n > (Optional) Enter a short description for your Llama Stack: ",
@ -307,7 +309,7 @@ def _generate_run_config(
providers = build_config.distribution_spec.providers[api] providers = build_config.distribution_spec.providers[api]
for provider in providers: for provider in providers:
pid = provider.provider_id pid = provider.provider_type.split("::")[-1]
p = provider_registry[Api(api)][provider.provider_type] p = provider_registry[Api(api)][provider.provider_type]
if p.deprecation_error: if p.deprecation_error:


@ -18,10 +18,6 @@ UV_HTTP_TIMEOUT=${UV_HTTP_TIMEOUT:-500}
# mounting is not supported by docker buildx, so we use COPY instead # mounting is not supported by docker buildx, so we use COPY instead
USE_COPY_NOT_MOUNT=${USE_COPY_NOT_MOUNT:-} USE_COPY_NOT_MOUNT=${USE_COPY_NOT_MOUNT:-}
# Mount command for cache container .cache, can be overridden by the user if needed
MOUNT_CACHE=${MOUNT_CACHE:-"--mount=type=cache,id=llama-stack-cache,target=/root/.cache"}
# Path to the run.yaml file in the container # Path to the run.yaml file in the container
RUN_CONFIG_PATH=/app/run.yaml RUN_CONFIG_PATH=/app/run.yaml
@ -176,18 +172,13 @@ RUN pip install uv
EOF EOF
fi fi
# Set the link mode to copy so that uv doesn't attempt to symlink to the cache directory
add_to_container << EOF
ENV UV_LINK_MODE=copy
EOF
# Add pip dependencies first since llama-stack is what will change most often # Add pip dependencies first since llama-stack is what will change most often
# so we can reuse layers. # so we can reuse layers.
if [ -n "$normal_deps" ]; then if [ -n "$normal_deps" ]; then
read -ra pip_args <<< "$normal_deps" read -ra pip_args <<< "$normal_deps"
quoted_deps=$(printf " %q" "${pip_args[@]}") quoted_deps=$(printf " %q" "${pip_args[@]}")
add_to_container << EOF add_to_container << EOF
RUN $MOUNT_CACHE uv pip install $quoted_deps RUN uv pip install --no-cache $quoted_deps
EOF EOF
fi fi
@ -197,7 +188,7 @@ if [ -n "$optional_deps" ]; then
read -ra pip_args <<< "$part" read -ra pip_args <<< "$part"
quoted_deps=$(printf " %q" "${pip_args[@]}") quoted_deps=$(printf " %q" "${pip_args[@]}")
add_to_container <<EOF add_to_container <<EOF
RUN $MOUNT_CACHE uv pip install $quoted_deps RUN uv pip install --no-cache $quoted_deps
EOF EOF
done done
fi fi
@ -208,10 +199,10 @@ if [ -n "$external_provider_deps" ]; then
read -ra pip_args <<< "$part" read -ra pip_args <<< "$part"
quoted_deps=$(printf " %q" "${pip_args[@]}") quoted_deps=$(printf " %q" "${pip_args[@]}")
add_to_container <<EOF add_to_container <<EOF
RUN $MOUNT_CACHE uv pip install $quoted_deps RUN uv pip install --no-cache $quoted_deps
EOF EOF
add_to_container <<EOF add_to_container <<EOF
RUN python3 - <<PYTHON | $MOUNT_CACHE uv pip install -r - RUN python3 - <<PYTHON | uv pip install --no-cache -r -
import importlib import importlib
import sys import sys
@ -293,7 +284,7 @@ COPY $dir $mount_point
EOF EOF
fi fi
add_to_container << EOF add_to_container << EOF
RUN $MOUNT_CACHE uv pip install -e $mount_point RUN uv pip install --no-cache -e $mount_point
EOF EOF
} }
@ -308,10 +299,10 @@ else
if [ -n "$TEST_PYPI_VERSION" ]; then if [ -n "$TEST_PYPI_VERSION" ]; then
# these packages are damaged in test-pypi, so install them first # these packages are damaged in test-pypi, so install them first
add_to_container << EOF add_to_container << EOF
RUN $MOUNT_CACHE uv pip install fastapi libcst RUN uv pip install --no-cache fastapi libcst
EOF EOF
add_to_container << EOF add_to_container << EOF
RUN $MOUNT_CACHE uv pip install --extra-index-url https://test.pypi.org/simple/ \ RUN uv pip install --no-cache --extra-index-url https://test.pypi.org/simple/ \
--index-strategy unsafe-best-match \ --index-strategy unsafe-best-match \
llama-stack==$TEST_PYPI_VERSION llama-stack==$TEST_PYPI_VERSION
@ -323,7 +314,7 @@ EOF
SPEC_VERSION="llama-stack" SPEC_VERSION="llama-stack"
fi fi
add_to_container << EOF add_to_container << EOF
RUN $MOUNT_CACHE uv pip install $SPEC_VERSION RUN uv pip install --no-cache $SPEC_VERSION
EOF EOF
fi fi
fi fi


@@ -100,11 +100,12 @@ def configure_api_providers(config: StackRunConfig, build_spec: DistributionSpec
                 break

         logger.info(f"> Configuring provider `({provider.provider_type})`")
+        pid = provider.provider_type.split("::")[-1]
         updated_providers.append(
             configure_single_provider(
                 provider_registry[api],
                 Provider(
-                    provider_id=(f"{provider.provider_id}-{i:02d}" if len(plist) > 1 else provider.provider_id),
+                    provider_id=(f"{pid}-{i:02d}" if len(plist) > 1 else pid),
                     provider_type=provider.provider_type,
                     config={},
                 ),


@@ -154,13 +154,27 @@ class Provider(BaseModel):
     )


+class BuildProvider(BaseModel):
+    provider_type: str
+    module: str | None = Field(
+        default=None,
+        description="""
+Fully-qualified name of the external provider module to import. The module is expected to have:
+ - `get_adapter_impl(config, deps)`: returns the adapter implementation
+Example: `module: ramalama_stack`
+""",
+    )
+
+
 class DistributionSpec(BaseModel):
     description: str | None = Field(
         default="",
         description="Description of the distribution",
     )
     container_image: str | None = None
-    providers: dict[str, list[Provider]] = Field(
+    providers: dict[str, list[BuildProvider]] = Field(
         default_factory=dict,
         description="""
 Provider Types for each of the APIs provided by this distribution. If you
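For orientation, here is a small sketch constructing the new build-time types; it mirrors the `remote::ramalama` example from the external-providers doc earlier in this commit. The import path and the `remote::ollama` entry are assumptions for illustration.

```python
# Sketch: import path assumed; shapes follow the BuildProvider/DistributionSpec definitions above.
from llama_stack.distribution.datatypes import BuildProvider, DistributionSpec

spec = DistributionSpec(
    description="Starter-style build with an external inference provider",
    providers={
        "inference": [
            # Build-time providers are identified by type only; provider_id and config
            # are resolved later when the run config is generated.
            BuildProvider(provider_type="remote::ollama"),
            BuildProvider(provider_type="remote::ramalama", module="ramalama_stack"),
        ],
    },
)
print(spec.providers["inference"][1].module)  # -> "ramalama_stack"
```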


@ -33,7 +33,7 @@ from termcolor import cprint
from llama_stack.distribution.build import print_pip_install_help from llama_stack.distribution.build import print_pip_install_help
from llama_stack.distribution.configure import parse_and_maybe_upgrade_config from llama_stack.distribution.configure import parse_and_maybe_upgrade_config
from llama_stack.distribution.datatypes import Api, BuildConfig, DistributionSpec from llama_stack.distribution.datatypes import Api, BuildConfig, BuildProvider, DistributionSpec
from llama_stack.distribution.request_headers import ( from llama_stack.distribution.request_headers import (
PROVIDER_DATA_VAR, PROVIDER_DATA_VAR,
request_provider_data_context, request_provider_data_context,
@ -249,9 +249,16 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
file=sys.stderr, file=sys.stderr,
) )
if self.config_path_or_template_name.endswith(".yaml"): if self.config_path_or_template_name.endswith(".yaml"):
providers: dict[str, list[BuildProvider]] = {}
for api, run_providers in self.config.providers.items():
for provider in run_providers:
providers.setdefault(api, []).append(
BuildProvider(provider_type=provider.provider_type, module=provider.module)
)
providers = dict(providers)
build_config = BuildConfig( build_config = BuildConfig(
distribution_spec=DistributionSpec( distribution_spec=DistributionSpec(
providers=self.config.providers, providers=providers,
), ),
external_providers_dir=self.config.external_providers_dir, external_providers_dir=self.config.external_providers_dir,
) )


@@ -25,7 +25,8 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
     async def refresh(self) -> None:
         for provider_id, provider in self.impls_by_provider_id.items():
             refresh = await provider.should_refresh_models()
-            if not (refresh or provider_id in self.listed_providers):
+            refresh = refresh or provider_id not in self.listed_providers
+            if not refresh:
                 continue

             try:
@@ -138,6 +139,9 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
                     # avoid overwriting a non-provider-registered model entry
                     continue

+            if model.identifier == model.provider_resource_id:
+                model.identifier = f"{provider_id}/{model.provider_resource_id}"
+
             logger.debug(f"registering model {model.identifier} ({model.provider_resource_id})")
             await self.register_object(
                 ModelWithOwner(
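The effect of the new identifier prefixing is easiest to see with a concrete value: a dynamically listed model whose identifier equals its provider resource ID gets namespaced as `<provider_id>/<provider_resource_id>`, which matches the `ollama/llama3.2:3b-instruct-fp16` style used in the CI workflow at the top of this diff. A minimal illustration with plain values, not the routing-table code itself:

```python
# Illustration of the renaming rule added above.
provider_id = "ollama"
identifier = "llama3.2:3b-instruct-fp16"
provider_resource_id = "llama3.2:3b-instruct-fp16"

if identifier == provider_resource_id:
    identifier = f"{provider_id}/{provider_resource_id}"

print(identifier)  # -> "ollama/llama3.2:3b-instruct-fp16"
```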


@@ -611,11 +611,8 @@ def extract_path_params(route: str) -> list[str]:
 def remove_disabled_providers(obj):
     if isinstance(obj, dict):
-        if (
-            obj.get("provider_id") == "__disabled__"
-            or obj.get("shield_id") == "__disabled__"
-            or obj.get("provider_model_id") == "__disabled__"
-        ):
+        keys = ["provider_id", "shield_id", "provider_model_id", "model_id"]
+        if any(k in obj and obj[k] in ("__disabled__", "", None) for k in keys):
             return None
         return {k: v for k, v in ((k, remove_disabled_providers(v)) for k, v in obj.items()) if v is not None}
     elif isinstance(obj, list):
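To make the broadened filter concrete, the snippet below re-implements the same rule on sample data: any mapping whose `provider_id`, `shield_id`, `provider_model_id`, or `model_id` is `__disabled__`, empty, or `None` is now pruned, whereas the old code only dropped the literal `__disabled__` cases. This is a standalone sketch, not the module above.

```python
# Standalone sketch of the pruning rule shown in the hunk above.
def remove_disabled(obj):
    keys = ["provider_id", "shield_id", "provider_model_id", "model_id"]
    if isinstance(obj, dict):
        if any(k in obj and obj[k] in ("__disabled__", "", None) for k in keys):
            return None
        return {k: v for k, v in ((k, remove_disabled(v)) for k, v in obj.items()) if v is not None}
    if isinstance(obj, list):
        return [v for v in (remove_disabled(x) for x in obj) if v is not None]
    return obj

run_config = {
    "models": [
        {"model_id": "ollama/llama3.2:3b", "provider_id": "ollama"},
        {"model_id": "", "provider_id": "vllm"},           # dropped: empty model_id
        {"model_id": "x", "provider_id": "__disabled__"},  # dropped: disabled provider
    ]
}
print(remove_disabled(run_config))  # only the first model survives
```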


@@ -105,23 +105,10 @@ async def register_resources(run_config: StackRunConfig, impls: dict[Api, Any]):
         method = getattr(impls[api], register_method)
         for obj in objects:
             logger.debug(f"registering {rsrc.capitalize()} {obj} for provider {obj.provider_id}")
-            # Do not register models on disabled providers
-            if hasattr(obj, "provider_id") and obj.provider_id is not None and obj.provider_id == "__disabled__":
-                logger.debug(f"Skipping {rsrc.capitalize()} registration for disabled provider.")
-                continue
-            # In complex templates, like our starter template, we may have dynamic model ids
-            # given by environment variables. This allows those environment variables to have
-            # a default value of __disabled__ to skip registration of the model if not set.
-            if (
-                hasattr(obj, "provider_model_id")
-                and obj.provider_model_id is not None
-                and "__disabled__" in obj.provider_model_id
-            ):
-                logger.debug(f"Skipping {rsrc.capitalize()} registration for disabled model.")
-                continue
-            if hasattr(obj, "shield_id") and obj.shield_id is not None and obj.shield_id == "__disabled__":
-                logger.debug(f"Skipping {rsrc.capitalize()} registration for disabled shield.")
+
+            # Do not register models on disabled providers
+            if hasattr(obj, "provider_id") and (not obj.provider_id or obj.provider_id == "__disabled__"):
+                logger.debug(f"Skipping {rsrc.capitalize()} registration for disabled provider.")
                 continue

             # we want to maintain the type information in arguments to method.
@@ -331,8 +318,10 @@ async def construct_stack(

     await register_resources(run_config, impls)

+    await refresh_registry_once(impls)
+
     global REGISTRY_REFRESH_TASK
-    REGISTRY_REFRESH_TASK = asyncio.create_task(refresh_registry(impls))
+    REGISTRY_REFRESH_TASK = asyncio.create_task(refresh_registry_task(impls))

     def cb(task):
         import traceback
@@ -368,12 +357,18 @@ async def shutdown_stack(impls: dict[Api, Any]):
         REGISTRY_REFRESH_TASK.cancel()


-async def refresh_registry(impls: dict[Api, Any]):
+async def refresh_registry_once(impls: dict[Api, Any]):
+    logger.debug("refreshing registry")
     routing_tables = [v for v in impls.values() if isinstance(v, CommonRoutingTableImpl)]
-    while True:
-        for routing_table in routing_tables:
-            await routing_table.refresh()
+    for routing_table in routing_tables:
+        await routing_table.refresh()
+
+
+async def refresh_registry_task(impls: dict[Api, Any]):
+    logger.info("starting registry refresh task")
+    while True:
+        await refresh_registry_once(impls)
         await asyncio.sleep(REGISTRY_REFRESH_INTERVAL_SECONDS)
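The restructuring above separates a one-shot refresh, run synchronously during stack construction, from the periodic background task that repeats it. The sketch below shows the same pattern in isolation with a dummy refresh function and a short interval; the names and interval here are illustrative, not the Llama Stack APIs.

```python
import asyncio

REFRESH_INTERVAL_SECONDS = 0.1  # illustrative; the real stack uses a much longer interval

async def refresh_once() -> None:
    # stand-in for iterating routing tables and awaiting .refresh() on each
    print("registry refreshed")

async def refresh_task() -> None:
    while True:
        await refresh_once()
        await asyncio.sleep(REFRESH_INTERVAL_SECONDS)

async def construct() -> None:
    await refresh_once()                        # populate the registry before serving
    task = asyncio.create_task(refresh_task())  # then keep it fresh in the background
    await asyncio.sleep(0.35)
    task.cancel()

asyncio.run(construct())
```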


@@ -43,6 +43,9 @@ class ModelsProtocolPrivate(Protocol):
        -> Provider uses provider-model-id for inference
     """

+    # this should be called `on_model_register` or something like that.
+    # the provider should _not_ be able to change the object in this
+    # callback
     async def register_model(self, model: Model) -> Model: ...

     async def unregister_model(self, model_id: str) -> None: ...


@@ -146,9 +146,9 @@ class LlamaGuardSafetyImpl(Safety, ShieldsProtocolPrivate):
         pass

     async def register_shield(self, shield: Shield) -> None:
-        # Allow any model to be registered as a shield
-        # The model will be validated during runtime when making inference calls
-        pass
+        model_id = shield.provider_resource_id
+        if not model_id:
+            raise ValueError("Llama Guard shield must have a model id")

     async def run_shield(
         self,


@ -15,6 +15,7 @@ class AnthropicInferenceAdapter(LiteLLMOpenAIMixin):
LiteLLMOpenAIMixin.__init__( LiteLLMOpenAIMixin.__init__(
self, self,
MODEL_ENTRIES, MODEL_ENTRIES,
litellm_provider_name="anthropic",
api_key_from_config=config.api_key, api_key_from_config=config.api_key,
provider_data_api_key_field="anthropic_api_key", provider_data_api_key_field="anthropic_api_key",
) )


@@ -26,7 +26,7 @@ class AnthropicConfig(BaseModel):
     )

     @classmethod
-    def sample_run_config(cls, api_key: str = "${env.ANTHROPIC_API_KEY}", **kwargs) -> dict[str, Any]:
+    def sample_run_config(cls, api_key: str = "${env.ANTHROPIC_API_KEY:=}", **kwargs) -> dict[str, Any]:
         return {
             "api_key": api_key,
         }


@ -10,9 +10,9 @@ from llama_stack.providers.utils.inference.model_registry import (
) )
LLM_MODEL_IDS = [ LLM_MODEL_IDS = [
"anthropic/claude-3-5-sonnet-latest", "claude-3-5-sonnet-latest",
"anthropic/claude-3-7-sonnet-latest", "claude-3-7-sonnet-latest",
"anthropic/claude-3-5-haiku-latest", "claude-3-5-haiku-latest",
] ]
SAFETY_MODELS_ENTRIES = [] SAFETY_MODELS_ENTRIES = []
@ -21,17 +21,17 @@ MODEL_ENTRIES = (
[ProviderModelEntry(provider_model_id=m) for m in LLM_MODEL_IDS] [ProviderModelEntry(provider_model_id=m) for m in LLM_MODEL_IDS]
+ [ + [
ProviderModelEntry( ProviderModelEntry(
provider_model_id="anthropic/voyage-3", provider_model_id="voyage-3",
model_type=ModelType.embedding, model_type=ModelType.embedding,
metadata={"embedding_dimension": 1024, "context_length": 32000}, metadata={"embedding_dimension": 1024, "context_length": 32000},
), ),
ProviderModelEntry( ProviderModelEntry(
provider_model_id="anthropic/voyage-3-lite", provider_model_id="voyage-3-lite",
model_type=ModelType.embedding, model_type=ModelType.embedding,
metadata={"embedding_dimension": 512, "context_length": 32000}, metadata={"embedding_dimension": 512, "context_length": 32000},
), ),
ProviderModelEntry( ProviderModelEntry(
provider_model_id="anthropic/voyage-code-3", provider_model_id="voyage-code-3",
model_type=ModelType.embedding, model_type=ModelType.embedding,
metadata={"embedding_dimension": 1024, "context_length": 32000}, metadata={"embedding_dimension": 1024, "context_length": 32000},
), ),


@ -63,18 +63,20 @@ class BedrockInferenceAdapter(
def __init__(self, config: BedrockConfig) -> None: def __init__(self, config: BedrockConfig) -> None:
ModelRegistryHelper.__init__(self, MODEL_ENTRIES) ModelRegistryHelper.__init__(self, MODEL_ENTRIES)
self._config = config self._config = config
self._client = None
self._client = create_bedrock_client(config)
@property @property
def client(self) -> BaseClient: def client(self) -> BaseClient:
if self._client is None:
self._client = create_bedrock_client(self._config)
return self._client return self._client
async def initialize(self) -> None: async def initialize(self) -> None:
pass pass
async def shutdown(self) -> None: async def shutdown(self) -> None:
self.client.close() if self._client is not None:
self._client.close()
async def completion( async def completion(
self, self,


@ -65,6 +65,7 @@ class CerebrasInferenceAdapter(
) )
self.config = config self.config = config
# TODO: make this use provider data, etc. like other providers
self.client = AsyncCerebras( self.client = AsyncCerebras(
base_url=self.config.base_url, base_url=self.config.base_url,
api_key=self.config.api_key.get_secret_value(), api_key=self.config.api_key.get_secret_value(),


@@ -26,7 +26,7 @@ class CerebrasImplConfig(BaseModel):
     )

     @classmethod
-    def sample_run_config(cls, api_key: str = "${env.CEREBRAS_API_KEY}", **kwargs) -> dict[str, Any]:
+    def sample_run_config(cls, api_key: str = "${env.CEREBRAS_API_KEY:=}", **kwargs) -> dict[str, Any]:
         return {
             "base_url": DEFAULT_BASE_URL,
             "api_key": api_key,


@@ -25,8 +25,8 @@ class DatabricksImplConfig(BaseModel):
     @classmethod
     def sample_run_config(
         cls,
-        url: str = "${env.DATABRICKS_URL}",
-        api_token: str = "${env.DATABRICKS_API_TOKEN}",
+        url: str = "${env.DATABRICKS_URL:=}",
+        api_token: str = "${env.DATABRICKS_API_TOKEN:=}",
         **kwargs: Any,
     ) -> dict[str, Any]:
         return {


@@ -24,7 +24,7 @@ class FireworksImplConfig(RemoteInferenceProviderConfig):
     )

     @classmethod
-    def sample_run_config(cls, api_key: str = "${env.FIREWORKS_API_KEY}", **kwargs) -> dict[str, Any]:
+    def sample_run_config(cls, api_key: str = "${env.FIREWORKS_API_KEY:=}", **kwargs) -> dict[str, Any]:
         return {
             "url": "https://api.fireworks.ai/inference/v1",
             "api_key": api_key,


@@ -26,7 +26,7 @@ class GeminiConfig(BaseModel):
     )

     @classmethod
-    def sample_run_config(cls, api_key: str = "${env.GEMINI_API_KEY}", **kwargs) -> dict[str, Any]:
+    def sample_run_config(cls, api_key: str = "${env.GEMINI_API_KEY:=}", **kwargs) -> dict[str, Any]:
         return {
             "api_key": api_key,
         }


@ -15,6 +15,7 @@ class GeminiInferenceAdapter(LiteLLMOpenAIMixin):
LiteLLMOpenAIMixin.__init__( LiteLLMOpenAIMixin.__init__(
self, self,
MODEL_ENTRIES, MODEL_ENTRIES,
litellm_provider_name="gemini",
api_key_from_config=config.api_key, api_key_from_config=config.api_key,
provider_data_api_key_field="gemini_api_key", provider_data_api_key_field="gemini_api_key",
) )


@ -10,11 +10,11 @@ from llama_stack.providers.utils.inference.model_registry import (
) )
LLM_MODEL_IDS = [ LLM_MODEL_IDS = [
"gemini/gemini-1.5-flash", "gemini-1.5-flash",
"gemini/gemini-1.5-pro", "gemini-1.5-pro",
"gemini/gemini-2.0-flash", "gemini-2.0-flash",
"gemini/gemini-2.5-flash", "gemini-2.5-flash",
"gemini/gemini-2.5-pro", "gemini-2.5-pro",
] ]
SAFETY_MODELS_ENTRIES = [] SAFETY_MODELS_ENTRIES = []
@ -23,7 +23,7 @@ MODEL_ENTRIES = (
[ProviderModelEntry(provider_model_id=m) for m in LLM_MODEL_IDS] [ProviderModelEntry(provider_model_id=m) for m in LLM_MODEL_IDS]
+ [ + [
ProviderModelEntry( ProviderModelEntry(
provider_model_id="gemini/text-embedding-004", provider_model_id="text-embedding-004",
model_type=ModelType.embedding, model_type=ModelType.embedding,
metadata={"embedding_dimension": 768, "context_length": 2048}, metadata={"embedding_dimension": 768, "context_length": 2048},
), ),


@@ -32,7 +32,7 @@ class GroqConfig(BaseModel):
     )

     @classmethod
-    def sample_run_config(cls, api_key: str = "${env.GROQ_API_KEY}", **kwargs) -> dict[str, Any]:
+    def sample_run_config(cls, api_key: str = "${env.GROQ_API_KEY:=}", **kwargs) -> dict[str, Any]:
         return {
             "url": "https://api.groq.com",
             "api_key": api_key,


@ -34,6 +34,7 @@ class GroqInferenceAdapter(LiteLLMOpenAIMixin):
LiteLLMOpenAIMixin.__init__( LiteLLMOpenAIMixin.__init__(
self, self,
model_entries=MODEL_ENTRIES, model_entries=MODEL_ENTRIES,
litellm_provider_name="groq",
api_key_from_config=config.api_key, api_key_from_config=config.api_key,
provider_data_api_key_field="groq_api_key", provider_data_api_key_field="groq_api_key",
) )
@ -96,7 +97,7 @@ class GroqInferenceAdapter(LiteLLMOpenAIMixin):
tool_choice = "required" tool_choice = "required"
params = await prepare_openai_completion_params( params = await prepare_openai_completion_params(
model=model_obj.provider_resource_id.replace("groq/", ""), model=model_obj.provider_resource_id,
messages=messages, messages=messages,
frequency_penalty=frequency_penalty, frequency_penalty=frequency_penalty,
function_call=function_call, function_call=function_call,


@ -14,19 +14,19 @@ SAFETY_MODELS_ENTRIES = []
MODEL_ENTRIES = [ MODEL_ENTRIES = [
build_hf_repo_model_entry( build_hf_repo_model_entry(
"groq/llama3-8b-8192", "llama3-8b-8192",
CoreModelId.llama3_1_8b_instruct.value, CoreModelId.llama3_1_8b_instruct.value,
), ),
build_model_entry( build_model_entry(
"groq/llama-3.1-8b-instant", "llama-3.1-8b-instant",
CoreModelId.llama3_1_8b_instruct.value, CoreModelId.llama3_1_8b_instruct.value,
), ),
build_hf_repo_model_entry( build_hf_repo_model_entry(
"groq/llama3-70b-8192", "llama3-70b-8192",
CoreModelId.llama3_70b_instruct.value, CoreModelId.llama3_70b_instruct.value,
), ),
build_hf_repo_model_entry( build_hf_repo_model_entry(
"groq/llama-3.3-70b-versatile", "llama-3.3-70b-versatile",
CoreModelId.llama3_3_70b_instruct.value, CoreModelId.llama3_3_70b_instruct.value,
), ),
# Groq only contains a preview version for llama-3.2-3b # Groq only contains a preview version for llama-3.2-3b
@ -34,23 +34,15 @@ MODEL_ENTRIES = [
# to pass the test fixture # to pass the test fixture
# TODO(aidand): Replace this with a stable model once Groq supports it # TODO(aidand): Replace this with a stable model once Groq supports it
build_hf_repo_model_entry( build_hf_repo_model_entry(
"groq/llama-3.2-3b-preview", "llama-3.2-3b-preview",
CoreModelId.llama3_2_3b_instruct.value, CoreModelId.llama3_2_3b_instruct.value,
), ),
build_hf_repo_model_entry( build_hf_repo_model_entry(
"groq/llama-4-scout-17b-16e-instruct", "meta-llama/llama-4-scout-17b-16e-instruct",
CoreModelId.llama4_scout_17b_16e_instruct.value, CoreModelId.llama4_scout_17b_16e_instruct.value,
), ),
build_hf_repo_model_entry( build_hf_repo_model_entry(
"groq/meta-llama/llama-4-scout-17b-16e-instruct", "meta-llama/llama-4-maverick-17b-128e-instruct",
CoreModelId.llama4_scout_17b_16e_instruct.value,
),
build_hf_repo_model_entry(
"groq/llama-4-maverick-17b-128e-instruct",
CoreModelId.llama4_maverick_17b_128e_instruct.value,
),
build_hf_repo_model_entry(
"groq/meta-llama/llama-4-maverick-17b-128e-instruct",
CoreModelId.llama4_maverick_17b_128e_instruct.value, CoreModelId.llama4_maverick_17b_128e_instruct.value,
), ),
] + SAFETY_MODELS_ENTRIES ] + SAFETY_MODELS_ENTRIES


@ -32,6 +32,7 @@ class LlamaCompatInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin):
LiteLLMOpenAIMixin.__init__( LiteLLMOpenAIMixin.__init__(
self, self,
model_entries=MODEL_ENTRIES, model_entries=MODEL_ENTRIES,
litellm_provider_name="meta_llama",
api_key_from_config=config.api_key, api_key_from_config=config.api_key,
provider_data_api_key_field="llama_api_key", provider_data_api_key_field="llama_api_key",
openai_compat_api_base=config.openai_compat_api_base, openai_compat_api_base=config.openai_compat_api_base,


@ -166,7 +166,7 @@ class OllamaInferenceAdapter(
] ]
for m in response.models: for m in response.models:
# kill embedding models since we don't know dimensions for them # kill embedding models since we don't know dimensions for them
if m.details.family in ["bert"]: if "bert" in m.details.family:
continue continue
models.append( models.append(
Model( Model(
@ -420,9 +420,6 @@ class OllamaInferenceAdapter(
except ValueError: except ValueError:
pass # Ignore statically unknown model, will check live listing pass # Ignore statically unknown model, will check live listing
if model.provider_resource_id is None:
raise ValueError("Model provider_resource_id cannot be None")
if model.model_type == ModelType.embedding: if model.model_type == ModelType.embedding:
response = await self.client.list() response = await self.client.list()
if model.provider_resource_id not in [m.model for m in response.models]: if model.provider_resource_id not in [m.model for m in response.models]:
@ -433,9 +430,9 @@ class OllamaInferenceAdapter(
# - models not currently running are run by the ollama server as needed # - models not currently running are run by the ollama server as needed
response = await self.client.list() response = await self.client.list()
available_models = [m.model for m in response.models] available_models = [m.model for m in response.models]
provider_resource_id = self.register_helper.get_provider_model_id(model.provider_resource_id)
if provider_resource_id is None:
provider_resource_id = model.provider_resource_id provider_resource_id = model.provider_resource_id
assert provider_resource_id is not None # mypy
if provider_resource_id not in available_models: if provider_resource_id not in available_models:
available_models_latest = [m.model.split(":latest")[0] for m in response.models] available_models_latest = [m.model.split(":latest")[0] for m in response.models]
if provider_resource_id in available_models_latest: if provider_resource_id in available_models_latest:
@ -443,7 +440,9 @@ class OllamaInferenceAdapter(
f"Imprecise provider resource id was used but 'latest' is available in Ollama - using '{model.provider_resource_id}:latest'" f"Imprecise provider resource id was used but 'latest' is available in Ollama - using '{model.provider_resource_id}:latest'"
) )
return model return model
raise UnsupportedModelError(model.provider_resource_id, available_models) raise UnsupportedModelError(provider_resource_id, available_models)
# mutating this should be considered an anti-pattern
model.provider_resource_id = provider_resource_id model.provider_resource_id = provider_resource_id
return model return model
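
A note on the resolution order in the hunk above: the adapter now asks the registry helper for the provider model id first, falls back to the id carried on the model, and accepts a ":latest"-tagged variant when Ollama only lists the model under that tag. A minimal standalone sketch of that fallback (a hypothetical helper, not the adapter method; the model names are examples):

def resolve_ollama_id(requested: str, available_models: list[str]) -> str:
    # Exact match first: the model is served under the id we were given.
    if requested in available_models:
        return requested
    # Ollama frequently lists models as "<name>:latest"; accept the bare name too.
    if any(m.split(":latest")[0] == requested for m in available_models):
        return f"{requested}:latest"
    raise ValueError(f"Model '{requested}' not served by Ollama; available: {available_models}")

# resolve_ollama_id("llama3.2", ["llama3.2:latest"]) -> "llama3.2:latest"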

View file

@ -24,9 +24,19 @@ class OpenAIConfig(BaseModel):
default=None, default=None,
description="API key for OpenAI models", description="API key for OpenAI models",
) )
base_url: str = Field(
default="https://api.openai.com/v1",
description="Base URL for OpenAI API",
)
@classmethod @classmethod
def sample_run_config(cls, api_key: str = "${env.OPENAI_API_KEY}", **kwargs) -> dict[str, Any]: def sample_run_config(
cls,
api_key: str = "${env.OPENAI_API_KEY:=}",
base_url: str = "${env.OPENAI_BASE_URL:=https://api.openai.com/v1}",
**kwargs,
) -> dict[str, Any]:
return { return {
"api_key": api_key, "api_key": api_key,
"base_url": base_url,
} }
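
For illustration, the new base_url field lets the same provider config point at any OpenAI-compatible endpoint instead of a hard-coded host. A trimmed sketch assuming only the two fields shown above (the gateway URL is made up):

from pydantic import BaseModel, Field

class OpenAIConfigSketch(BaseModel):
    # Reduced copy of the config above, for illustration only.
    api_key: str | None = Field(default=None, description="API key for OpenAI models")
    base_url: str = Field(default="https://api.openai.com/v1", description="Base URL for OpenAI API")

# Default keeps the public OpenAI endpoint; override to route through a compatible gateway.
cfg = OpenAIConfigSketch(api_key="sk-test", base_url="https://llm-gateway.example.com/v1")
print(cfg.base_url)  # https://llm-gateway.example.com/v1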

View file

@ -45,6 +45,7 @@ class OpenAIInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin):
LiteLLMOpenAIMixin.__init__( LiteLLMOpenAIMixin.__init__(
self, self,
MODEL_ENTRIES, MODEL_ENTRIES,
litellm_provider_name="openai",
api_key_from_config=config.api_key, api_key_from_config=config.api_key,
provider_data_api_key_field="openai_api_key", provider_data_api_key_field="openai_api_key",
) )
@ -64,9 +65,9 @@ class OpenAIInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin):
""" """
Get the OpenAI API base URL. Get the OpenAI API base URL.
Returns the standard OpenAI API base URL for direct OpenAI API calls. Returns the OpenAI API base URL from the configuration.
""" """
return "https://api.openai.com/v1" return self.config.base_url
async def initialize(self) -> None: async def initialize(self) -> None:
await super().initialize() await super().initialize()

View file

@ -30,7 +30,7 @@ class SambaNovaImplConfig(BaseModel):
) )
@classmethod @classmethod
def sample_run_config(cls, api_key: str = "${env.SAMBANOVA_API_KEY}", **kwargs) -> dict[str, Any]: def sample_run_config(cls, api_key: str = "${env.SAMBANOVA_API_KEY:=}", **kwargs) -> dict[str, Any]:
return { return {
"url": "https://api.sambanova.ai/v1", "url": "https://api.sambanova.ai/v1",
"api_key": api_key, "api_key": api_key,

View file

@ -9,49 +9,20 @@ from llama_stack.providers.utils.inference.model_registry import (
build_hf_repo_model_entry, build_hf_repo_model_entry,
) )
SAFETY_MODELS_ENTRIES = [ SAFETY_MODELS_ENTRIES = []
build_hf_repo_model_entry(
"sambanova/Meta-Llama-Guard-3-8B",
CoreModelId.llama_guard_3_8b.value,
),
]
MODEL_ENTRIES = [ MODEL_ENTRIES = [
build_hf_repo_model_entry( build_hf_repo_model_entry(
"sambanova/Meta-Llama-3.1-8B-Instruct", "Meta-Llama-3.1-8B-Instruct",
CoreModelId.llama3_1_8b_instruct.value, CoreModelId.llama3_1_8b_instruct.value,
), ),
build_hf_repo_model_entry( build_hf_repo_model_entry(
"sambanova/Meta-Llama-3.1-405B-Instruct", "Meta-Llama-3.3-70B-Instruct",
CoreModelId.llama3_1_405b_instruct.value,
),
build_hf_repo_model_entry(
"sambanova/Meta-Llama-3.2-1B-Instruct",
CoreModelId.llama3_2_1b_instruct.value,
),
build_hf_repo_model_entry(
"sambanova/Meta-Llama-3.2-3B-Instruct",
CoreModelId.llama3_2_3b_instruct.value,
),
build_hf_repo_model_entry(
"sambanova/Meta-Llama-3.3-70B-Instruct",
CoreModelId.llama3_3_70b_instruct.value, CoreModelId.llama3_3_70b_instruct.value,
), ),
build_hf_repo_model_entry( build_hf_repo_model_entry(
"sambanova/Llama-3.2-11B-Vision-Instruct", "Llama-4-Maverick-17B-128E-Instruct",
CoreModelId.llama3_2_11b_vision_instruct.value,
),
build_hf_repo_model_entry(
"sambanova/Llama-3.2-90B-Vision-Instruct",
CoreModelId.llama3_2_90b_vision_instruct.value,
),
build_hf_repo_model_entry(
"sambanova/Llama-4-Scout-17B-16E-Instruct",
CoreModelId.llama4_scout_17b_16e_instruct.value,
),
build_hf_repo_model_entry(
"sambanova/Llama-4-Maverick-17B-128E-Instruct",
CoreModelId.llama4_maverick_17b_128e_instruct.value, CoreModelId.llama4_maverick_17b_128e_instruct.value,
), ),
] + SAFETY_MODELS_ENTRIES ] + SAFETY_MODELS_ENTRIES

View file

@ -182,6 +182,7 @@ class SambaNovaInferenceAdapter(LiteLLMOpenAIMixin):
LiteLLMOpenAIMixin.__init__( LiteLLMOpenAIMixin.__init__(
self, self,
model_entries=MODEL_ENTRIES, model_entries=MODEL_ENTRIES,
litellm_provider_name="sambanova",
api_key_from_config=self.config.api_key.get_secret_value() if self.config.api_key else None, api_key_from_config=self.config.api_key.get_secret_value() if self.config.api_key else None,
provider_data_api_key_field="sambanova_api_key", provider_data_api_key_field="sambanova_api_key",
) )

View file

@ -19,7 +19,7 @@ class TGIImplConfig(BaseModel):
@classmethod @classmethod
def sample_run_config( def sample_run_config(
cls, cls,
url: str = "${env.TGI_URL}", url: str = "${env.TGI_URL:=}",
**kwargs, **kwargs,
): ):
return { return {

View file

@ -305,6 +305,8 @@ class _HfAdapter(
class TGIAdapter(_HfAdapter): class TGIAdapter(_HfAdapter):
async def initialize(self, config: TGIImplConfig) -> None: async def initialize(self, config: TGIImplConfig) -> None:
if not config.url:
raise ValueError("You must provide a URL in run.yaml (or via the TGI_URL environment variable) to use TGI.")
log.info(f"Initializing TGI client with url={config.url}") log.info(f"Initializing TGI client with url={config.url}")
self.client = AsyncInferenceClient( self.client = AsyncInferenceClient(
model=config.url, model=config.url,

View file

@ -27,5 +27,5 @@ class TogetherImplConfig(RemoteInferenceProviderConfig):
def sample_run_config(cls, **kwargs) -> dict[str, Any]: def sample_run_config(cls, **kwargs) -> dict[str, Any]:
return { return {
"url": "https://api.together.xyz/v1", "url": "https://api.together.xyz/v1",
"api_key": "${env.TOGETHER_API_KEY}", "api_key": "${env.TOGETHER_API_KEY:=}",
} }

View file

@ -69,15 +69,9 @@ MODEL_ENTRIES = [
build_hf_repo_model_entry( build_hf_repo_model_entry(
"meta-llama/Llama-4-Scout-17B-16E-Instruct", "meta-llama/Llama-4-Scout-17B-16E-Instruct",
CoreModelId.llama4_scout_17b_16e_instruct.value, CoreModelId.llama4_scout_17b_16e_instruct.value,
additional_aliases=[
"together/meta-llama/Llama-4-Scout-17B-16E-Instruct",
],
), ),
build_hf_repo_model_entry( build_hf_repo_model_entry(
"meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8", "meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8",
CoreModelId.llama4_maverick_17b_128e_instruct.value, CoreModelId.llama4_maverick_17b_128e_instruct.value,
additional_aliases=[
"together/meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8",
],
), ),
] + SAFETY_MODELS_ENTRIES ] + SAFETY_MODELS_ENTRIES

View file

@ -299,7 +299,10 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
self.client = None self.client = None
async def initialize(self) -> None: async def initialize(self) -> None:
pass if not self.config.url:
raise ValueError(
"You must provide a URL in run.yaml (or via the VLLM_URL environment variable) to use vLLM."
)
async def should_refresh_models(self) -> bool: async def should_refresh_models(self) -> bool:
return self.config.refresh_models return self.config.refresh_models
@ -337,9 +340,6 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
HealthResponse: A dictionary containing the health status. HealthResponse: A dictionary containing the health status.
""" """
try: try:
if not self.config.url:
return HealthResponse(status=HealthStatus.ERROR, message="vLLM URL is not set")
client = self._create_client() if self.client is None else self.client client = self._create_client() if self.client is None else self.client
_ = [m async for m in client.models.list()] # Ensure the client is initialized _ = [m async for m in client.models.list()] # Ensure the client is initialized
return HealthResponse(status=HealthStatus.OK) return HealthResponse(status=HealthStatus.OK)
@ -355,11 +355,6 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
if self.client is not None: if self.client is not None:
return return
if not self.config.url:
raise ValueError(
"You must provide a vLLM URL in the run.yaml file (or set the VLLM_URL environment variable)"
)
log.info(f"Initializing vLLM client with base_url={self.config.url}") log.info(f"Initializing vLLM client with base_url={self.config.url}")
self.client = self._create_client() self.client = self._create_client()

View file

@ -30,7 +30,7 @@ class SambaNovaSafetyConfig(BaseModel):
) )
@classmethod @classmethod
def sample_run_config(cls, api_key: str = "${env.SAMBANOVA_API_KEY}", **kwargs) -> dict[str, Any]: def sample_run_config(cls, api_key: str = "${env.SAMBANOVA_API_KEY:=}", **kwargs) -> dict[str, Any]:
return { return {
"url": "https://api.sambanova.ai/v1", "url": "https://api.sambanova.ai/v1",
"api_key": api_key, "api_key": api_key,

View file

@ -68,11 +68,23 @@ class LiteLLMOpenAIMixin(
def __init__( def __init__(
self, self,
model_entries, model_entries,
litellm_provider_name: str,
api_key_from_config: str | None, api_key_from_config: str | None,
provider_data_api_key_field: str, provider_data_api_key_field: str,
openai_compat_api_base: str | None = None, openai_compat_api_base: str | None = None,
): ):
"""
Initialize the LiteLLMOpenAIMixin.
:param model_entries: The model entries to register.
:param api_key_from_config: The API key to use from the config.
:param provider_data_api_key_field: The field in the provider data that contains the API key.
:param litellm_provider_name: The name of the provider, used for model lookups.
:param openai_compat_api_base: The base URL for OpenAI compatibility, or None if not using OpenAI compatibility.
"""
ModelRegistryHelper.__init__(self, model_entries) ModelRegistryHelper.__init__(self, model_entries)
self.litellm_provider_name = litellm_provider_name
self.api_key_from_config = api_key_from_config self.api_key_from_config = api_key_from_config
self.provider_data_api_key_field = provider_data_api_key_field self.provider_data_api_key_field = provider_data_api_key_field
self.api_base = openai_compat_api_base self.api_base = openai_compat_api_base
@ -91,7 +103,11 @@ class LiteLLMOpenAIMixin(
def get_litellm_model_name(self, model_id: str) -> str: def get_litellm_model_name(self, model_id: str) -> str:
# users may be using openai/ prefix in their model names. the openai/models.py did this by default. # users may be using openai/ prefix in their model names. the openai/models.py did this by default.
# model_id.startswith("openai/") is for backwards compatibility. # model_id.startswith("openai/") is for backwards compatibility.
return "openai/" + model_id if self.is_openai_compat and not model_id.startswith("openai/") else model_id return (
f"{self.litellm_provider_name}/{model_id}"
if self.is_openai_compat and not model_id.startswith(self.litellm_provider_name)
else model_id
)
async def completion( async def completion(
self, self,
@ -421,3 +437,17 @@ class LiteLLMOpenAIMixin(
logprobs: LogProbConfig | None = None, logprobs: LogProbConfig | None = None,
): ):
raise NotImplementedError("Batch chat completion is not supported for OpenAI Compat") raise NotImplementedError("Batch chat completion is not supported for OpenAI Compat")
async def check_model_availability(self, model: str) -> bool:
"""
Check if a specific model is available via LiteLLM for the current
provider (self.litellm_provider_name).
:param model: The model identifier to check.
:return: True if the model is available dynamically, False otherwise.
"""
if self.litellm_provider_name not in litellm.models_by_provider:
logger.error(f"Provider {self.litellm_provider_name} is not registered in litellm.")
return False
return model in litellm.models_by_provider[self.litellm_provider_name]
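
Condensed, the two behaviors added to the mixin above are: prefix the model id with the litellm provider name when going through an OpenAI-compatible endpoint, and answer availability from litellm's provider registry. A standalone sketch assuming litellm is installed (the example ids are illustrative):

import litellm

def litellm_model_name(provider_name: str, model_id: str, openai_compat: bool) -> str:
    # Only prefix when using the OpenAI-compat path and the id is not already prefixed.
    if openai_compat and not model_id.startswith(provider_name):
        return f"{provider_name}/{model_id}"
    return model_id

def model_available(provider_name: str, model: str) -> bool:
    # Same lookup that check_model_availability performs, shown outside the mixin.
    return provider_name in litellm.models_by_provider and model in litellm.models_by_provider[provider_name]

# litellm_model_name("sambanova", "Meta-Llama-3.1-8B-Instruct", True)
#   -> "sambanova/Meta-Llama-3.1-8B-Instruct"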

View file

@ -50,7 +50,8 @@ def build_hf_repo_model_entry(
additional_aliases: list[str] | None = None, additional_aliases: list[str] | None = None,
) -> ProviderModelEntry: ) -> ProviderModelEntry:
aliases = [ aliases = [
get_huggingface_repo(model_descriptor), # NOTE: avoid HF aliases because they _cannot_ be unique across providers
# get_huggingface_repo(model_descriptor),
] ]
if additional_aliases: if additional_aliases:
aliases.extend(additional_aliases) aliases.extend(additional_aliases)
@ -75,7 +76,9 @@ class ModelRegistryHelper(ModelsProtocolPrivate):
__provider_id__: str __provider_id__: str
def __init__(self, model_entries: list[ProviderModelEntry], allowed_models: list[str] | None = None): def __init__(self, model_entries: list[ProviderModelEntry], allowed_models: list[str] | None = None):
self.model_entries = model_entries
self.allowed_models = allowed_models self.allowed_models = allowed_models
self.alias_to_provider_id_map = {} self.alias_to_provider_id_map = {}
self.provider_id_to_llama_model_map = {} self.provider_id_to_llama_model_map = {}
for entry in model_entries: for entry in model_entries:
@ -98,7 +101,7 @@ class ModelRegistryHelper(ModelsProtocolPrivate):
continue continue
models.append( models.append(
Model( Model(
model_id=id, identifier=id,
provider_resource_id=entry.provider_model_id, provider_resource_id=entry.provider_model_id,
model_type=ModelType.llm, model_type=ModelType.llm,
metadata=entry.metadata, metadata=entry.metadata,
@ -185,8 +188,8 @@ class ModelRegistryHelper(ModelsProtocolPrivate):
return model return model
async def unregister_model(self, model_id: str) -> None: async def unregister_model(self, model_id: str) -> None:
# TODO: should we block unregistering base supported provider model IDs? # model_id is the identifier, not the provider_resource_id
if model_id not in self.alias_to_provider_id_map: # unfortunately, this ID can be of the form provider_id/model_id which
raise ValueError(f"Model id '{model_id}' is not registered.") # we never registered. TODO: fix this by significantly rewriting
# registration and registry helper
del self.alias_to_provider_id_map[model_id] pass
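
The NOTE in the hunk above ("HF aliases cannot be unique across providers") is easiest to see with a toy collision: two providers expose the same Llama model, so both would claim the identical HuggingFace-repo alias in a flat alias map. A sketch of the conflict (the provider-side ids are illustrative, not taken from the registry):

alias_to_provider_id_map: dict[str, str] = {}

def register_alias(alias: str, provider_model_id: str) -> None:
    # A flat alias map can only hold one target per alias.
    existing = alias_to_provider_id_map.get(alias)
    if existing is not None and existing != provider_model_id:
        raise ValueError(f"alias '{alias}' already maps to '{existing}'")
    alias_to_provider_id_map[alias] = provider_model_id

register_alias("meta-llama/Llama-3.1-8B-Instruct", "Meta-Llama-3.1-8B-Instruct")
register_alias("meta-llama/Llama-3.1-8B-Instruct", "llama-v3p1-8b-instruct")  # raises ValueError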

View file

@ -3,96 +3,50 @@ distribution_spec:
description: CI tests for Llama Stack description: CI tests for Llama Stack
providers: providers:
inference: inference:
- provider_id: ${env.ENABLE_CEREBRAS:=__disabled__} - provider_type: remote::cerebras
provider_type: remote::cerebras - provider_type: remote::ollama
- provider_id: ${env.ENABLE_OLLAMA:=__disabled__} - provider_type: remote::vllm
provider_type: remote::ollama - provider_type: remote::tgi
- provider_id: ${env.ENABLE_VLLM:=__disabled__} - provider_type: remote::fireworks
provider_type: remote::vllm - provider_type: remote::together
- provider_id: ${env.ENABLE_TGI:=__disabled__} - provider_type: remote::bedrock
provider_type: remote::tgi - provider_type: remote::nvidia
- provider_id: ${env.ENABLE_HF_SERVERLESS:=__disabled__} - provider_type: remote::openai
provider_type: remote::hf::serverless - provider_type: remote::anthropic
- provider_id: ${env.ENABLE_HF_ENDPOINT:=__disabled__} - provider_type: remote::gemini
provider_type: remote::hf::endpoint - provider_type: remote::groq
- provider_id: ${env.ENABLE_FIREWORKS:=__disabled__} - provider_type: remote::sambanova
provider_type: remote::fireworks - provider_type: inline::sentence-transformers
- provider_id: ${env.ENABLE_TOGETHER:=__disabled__}
provider_type: remote::together
- provider_id: ${env.ENABLE_BEDROCK:=__disabled__}
provider_type: remote::bedrock
- provider_id: ${env.ENABLE_DATABRICKS:=__disabled__}
provider_type: remote::databricks
- provider_id: ${env.ENABLE_NVIDIA:=__disabled__}
provider_type: remote::nvidia
- provider_id: ${env.ENABLE_RUNPOD:=__disabled__}
provider_type: remote::runpod
- provider_id: ${env.ENABLE_OPENAI:=__disabled__}
provider_type: remote::openai
- provider_id: ${env.ENABLE_ANTHROPIC:=__disabled__}
provider_type: remote::anthropic
- provider_id: ${env.ENABLE_GEMINI:=__disabled__}
provider_type: remote::gemini
- provider_id: ${env.ENABLE_GROQ:=__disabled__}
provider_type: remote::groq
- provider_id: ${env.ENABLE_LLAMA_OPENAI_COMPAT:=__disabled__}
provider_type: remote::llama-openai-compat
- provider_id: ${env.ENABLE_SAMBANOVA:=__disabled__}
provider_type: remote::sambanova
- provider_id: ${env.ENABLE_PASSTHROUGH:=__disabled__}
provider_type: remote::passthrough
- provider_id: sentence-transformers
provider_type: inline::sentence-transformers
vector_io: vector_io:
- provider_id: ${env.ENABLE_FAISS:=faiss} - provider_type: inline::faiss
provider_type: inline::faiss - provider_type: inline::sqlite-vec
- provider_id: ${env.ENABLE_SQLITE_VEC:=__disabled__} - provider_type: inline::milvus
provider_type: inline::sqlite-vec - provider_type: remote::chromadb
- provider_id: ${env.ENABLE_MILVUS:=__disabled__} - provider_type: remote::pgvector
provider_type: inline::milvus
- provider_id: ${env.ENABLE_CHROMADB:=__disabled__}
provider_type: remote::chromadb
- provider_id: ${env.ENABLE_PGVECTOR:=__disabled__}
provider_type: remote::pgvector
files: files:
- provider_id: localfs - provider_type: inline::localfs
provider_type: inline::localfs
safety: safety:
- provider_id: llama-guard - provider_type: inline::llama-guard
provider_type: inline::llama-guard
agents: agents:
- provider_id: meta-reference - provider_type: inline::meta-reference
provider_type: inline::meta-reference
telemetry: telemetry:
- provider_id: meta-reference - provider_type: inline::meta-reference
provider_type: inline::meta-reference
post_training: post_training:
- provider_id: huggingface - provider_type: inline::huggingface
provider_type: inline::huggingface
eval: eval:
- provider_id: meta-reference - provider_type: inline::meta-reference
provider_type: inline::meta-reference
datasetio: datasetio:
- provider_id: huggingface - provider_type: remote::huggingface
provider_type: remote::huggingface - provider_type: inline::localfs
- provider_id: localfs
provider_type: inline::localfs
scoring: scoring:
- provider_id: basic - provider_type: inline::basic
provider_type: inline::basic - provider_type: inline::llm-as-judge
- provider_id: llm-as-judge - provider_type: inline::braintrust
provider_type: inline::llm-as-judge
- provider_id: braintrust
provider_type: inline::braintrust
tool_runtime: tool_runtime:
- provider_id: brave-search - provider_type: remote::brave-search
provider_type: remote::brave-search - provider_type: remote::tavily-search
- provider_id: tavily-search - provider_type: inline::rag-runtime
provider_type: remote::tavily-search - provider_type: remote::model-context-protocol
- provider_id: rag-runtime
provider_type: inline::rag-runtime
- provider_id: model-context-protocol
provider_type: remote::model-context-protocol
image_type: conda image_type: conda
image_name: ci-tests image_name: ci-tests
additional_pip_packages: additional_pip_packages:

File diff suppressed because it is too large

View file

@ -4,48 +4,31 @@ distribution_spec:
container container
providers: providers:
inference: inference:
- provider_id: tgi - provider_type: remote::tgi
provider_type: remote::tgi - provider_type: inline::sentence-transformers
- provider_id: sentence-transformers
provider_type: inline::sentence-transformers
vector_io: vector_io:
- provider_id: faiss - provider_type: inline::faiss
provider_type: inline::faiss - provider_type: remote::chromadb
- provider_id: chromadb - provider_type: remote::pgvector
provider_type: remote::chromadb
- provider_id: pgvector
provider_type: remote::pgvector
safety: safety:
- provider_id: llama-guard - provider_type: inline::llama-guard
provider_type: inline::llama-guard
agents: agents:
- provider_id: meta-reference - provider_type: inline::meta-reference
provider_type: inline::meta-reference
telemetry: telemetry:
- provider_id: meta-reference - provider_type: inline::meta-reference
provider_type: inline::meta-reference
eval: eval:
- provider_id: meta-reference - provider_type: inline::meta-reference
provider_type: inline::meta-reference
datasetio: datasetio:
- provider_id: huggingface - provider_type: remote::huggingface
provider_type: remote::huggingface - provider_type: inline::localfs
- provider_id: localfs
provider_type: inline::localfs
scoring: scoring:
- provider_id: basic - provider_type: inline::basic
provider_type: inline::basic - provider_type: inline::llm-as-judge
- provider_id: llm-as-judge - provider_type: inline::braintrust
provider_type: inline::llm-as-judge
- provider_id: braintrust
provider_type: inline::braintrust
tool_runtime: tool_runtime:
- provider_id: brave-search - provider_type: remote::brave-search
provider_type: remote::brave-search - provider_type: remote::tavily-search
- provider_id: tavily-search - provider_type: inline::rag-runtime
provider_type: remote::tavily-search
- provider_id: rag-runtime
provider_type: inline::rag-runtime
image_type: conda image_type: conda
image_name: dell image_name: dell
additional_pip_packages: additional_pip_packages:

View file

@ -6,6 +6,7 @@
from llama_stack.apis.models import ModelType from llama_stack.apis.models import ModelType
from llama_stack.distribution.datatypes import ( from llama_stack.distribution.datatypes import (
BuildProvider,
ModelInput, ModelInput,
Provider, Provider,
ShieldInput, ShieldInput,
@ -20,31 +21,31 @@ from llama_stack.templates.template import DistributionTemplate, RunConfigSettin
def get_distribution_template() -> DistributionTemplate: def get_distribution_template() -> DistributionTemplate:
providers = { providers = {
"inference": [ "inference": [
Provider(provider_id="tgi", provider_type="remote::tgi"), BuildProvider(provider_type="remote::tgi"),
Provider(provider_id="sentence-transformers", provider_type="inline::sentence-transformers"), BuildProvider(provider_type="inline::sentence-transformers"),
], ],
"vector_io": [ "vector_io": [
Provider(provider_id="faiss", provider_type="inline::faiss"), BuildProvider(provider_type="inline::faiss"),
Provider(provider_id="chromadb", provider_type="remote::chromadb"), BuildProvider(provider_type="remote::chromadb"),
Provider(provider_id="pgvector", provider_type="remote::pgvector"), BuildProvider(provider_type="remote::pgvector"),
], ],
"safety": [Provider(provider_id="llama-guard", provider_type="inline::llama-guard")], "safety": [BuildProvider(provider_type="inline::llama-guard")],
"agents": [Provider(provider_id="meta-reference", provider_type="inline::meta-reference")], "agents": [BuildProvider(provider_type="inline::meta-reference")],
"telemetry": [Provider(provider_id="meta-reference", provider_type="inline::meta-reference")], "telemetry": [BuildProvider(provider_type="inline::meta-reference")],
"eval": [Provider(provider_id="meta-reference", provider_type="inline::meta-reference")], "eval": [BuildProvider(provider_type="inline::meta-reference")],
"datasetio": [ "datasetio": [
Provider(provider_id="huggingface", provider_type="remote::huggingface"), BuildProvider(provider_type="remote::huggingface"),
Provider(provider_id="localfs", provider_type="inline::localfs"), BuildProvider(provider_type="inline::localfs"),
], ],
"scoring": [ "scoring": [
Provider(provider_id="basic", provider_type="inline::basic"), BuildProvider(provider_type="inline::basic"),
Provider(provider_id="llm-as-judge", provider_type="inline::llm-as-judge"), BuildProvider(provider_type="inline::llm-as-judge"),
Provider(provider_id="braintrust", provider_type="inline::braintrust"), BuildProvider(provider_type="inline::braintrust"),
], ],
"tool_runtime": [ "tool_runtime": [
Provider(provider_id="brave-search", provider_type="remote::brave-search"), BuildProvider(provider_type="remote::brave-search"),
Provider(provider_id="tavily-search", provider_type="remote::tavily-search"), BuildProvider(provider_type="remote::tavily-search"),
Provider(provider_id="rag-runtime", provider_type="inline::rag-runtime"), BuildProvider(provider_type="inline::rag-runtime"),
], ],
} }
name = "dell" name = "dell"

View file

@ -3,48 +3,31 @@ distribution_spec:
description: Use Meta Reference for running LLM inference description: Use Meta Reference for running LLM inference
providers: providers:
inference: inference:
- provider_id: meta-reference - provider_type: inline::meta-reference
provider_type: inline::meta-reference
vector_io: vector_io:
- provider_id: faiss - provider_type: inline::faiss
provider_type: inline::faiss - provider_type: remote::chromadb
- provider_id: chromadb - provider_type: remote::pgvector
provider_type: remote::chromadb
- provider_id: pgvector
provider_type: remote::pgvector
safety: safety:
- provider_id: llama-guard - provider_type: inline::llama-guard
provider_type: inline::llama-guard
agents: agents:
- provider_id: meta-reference - provider_type: inline::meta-reference
provider_type: inline::meta-reference
telemetry: telemetry:
- provider_id: meta-reference - provider_type: inline::meta-reference
provider_type: inline::meta-reference
eval: eval:
- provider_id: meta-reference - provider_type: inline::meta-reference
provider_type: inline::meta-reference
datasetio: datasetio:
- provider_id: huggingface - provider_type: remote::huggingface
provider_type: remote::huggingface - provider_type: inline::localfs
- provider_id: localfs
provider_type: inline::localfs
scoring: scoring:
- provider_id: basic - provider_type: inline::basic
provider_type: inline::basic - provider_type: inline::llm-as-judge
- provider_id: llm-as-judge - provider_type: inline::braintrust
provider_type: inline::llm-as-judge
- provider_id: braintrust
provider_type: inline::braintrust
tool_runtime: tool_runtime:
- provider_id: brave-search - provider_type: remote::brave-search
provider_type: remote::brave-search - provider_type: remote::tavily-search
- provider_id: tavily-search - provider_type: inline::rag-runtime
provider_type: remote::tavily-search - provider_type: remote::model-context-protocol
- provider_id: rag-runtime
provider_type: inline::rag-runtime
- provider_id: model-context-protocol
provider_type: remote::model-context-protocol
image_type: conda image_type: conda
image_name: meta-reference-gpu image_name: meta-reference-gpu
additional_pip_packages: additional_pip_packages:

View file

@ -8,6 +8,7 @@ from pathlib import Path
from llama_stack.apis.models import ModelType from llama_stack.apis.models import ModelType
from llama_stack.distribution.datatypes import ( from llama_stack.distribution.datatypes import (
BuildProvider,
ModelInput, ModelInput,
Provider, Provider,
ShieldInput, ShieldInput,
@ -25,91 +26,30 @@ from llama_stack.templates.template import DistributionTemplate, RunConfigSettin
def get_distribution_template() -> DistributionTemplate: def get_distribution_template() -> DistributionTemplate:
providers = { providers = {
"inference": [ "inference": [BuildProvider(provider_type="inline::meta-reference")],
Provider(
provider_id="meta-reference",
provider_type="inline::meta-reference",
)
],
"vector_io": [ "vector_io": [
Provider( BuildProvider(provider_type="inline::faiss"),
provider_id="faiss", BuildProvider(provider_type="remote::chromadb"),
provider_type="inline::faiss", BuildProvider(provider_type="remote::pgvector"),
),
Provider(
provider_id="chromadb",
provider_type="remote::chromadb",
),
Provider(
provider_id="pgvector",
provider_type="remote::pgvector",
),
],
"safety": [
Provider(
provider_id="llama-guard",
provider_type="inline::llama-guard",
)
],
"agents": [
Provider(
provider_id="meta-reference",
provider_type="inline::meta-reference",
)
],
"telemetry": [
Provider(
provider_id="meta-reference",
provider_type="inline::meta-reference",
)
],
"eval": [
Provider(
provider_id="meta-reference",
provider_type="inline::meta-reference",
)
], ],
"safety": [BuildProvider(provider_type="inline::llama-guard")],
"agents": [BuildProvider(provider_type="inline::meta-reference")],
"telemetry": [BuildProvider(provider_type="inline::meta-reference")],
"eval": [BuildProvider(provider_type="inline::meta-reference")],
"datasetio": [ "datasetio": [
Provider( BuildProvider(provider_type="remote::huggingface"),
provider_id="huggingface", BuildProvider(provider_type="inline::localfs"),
provider_type="remote::huggingface",
),
Provider(
provider_id="localfs",
provider_type="inline::localfs",
),
], ],
"scoring": [ "scoring": [
Provider( BuildProvider(provider_type="inline::basic"),
provider_id="basic", BuildProvider(provider_type="inline::llm-as-judge"),
provider_type="inline::basic", BuildProvider(provider_type="inline::braintrust"),
),
Provider(
provider_id="llm-as-judge",
provider_type="inline::llm-as-judge",
),
Provider(
provider_id="braintrust",
provider_type="inline::braintrust",
),
], ],
"tool_runtime": [ "tool_runtime": [
Provider( BuildProvider(provider_type="remote::brave-search"),
provider_id="brave-search", BuildProvider(provider_type="remote::tavily-search"),
provider_type="remote::brave-search", BuildProvider(provider_type="inline::rag-runtime"),
), BuildProvider(provider_type="remote::model-context-protocol"),
Provider(
provider_id="tavily-search",
provider_type="remote::tavily-search",
),
Provider(
provider_id="rag-runtime",
provider_type="inline::rag-runtime",
),
Provider(
provider_id="model-context-protocol",
provider_type="remote::model-context-protocol",
),
], ],
} }
name = "meta-reference-gpu" name = "meta-reference-gpu"

View file

@ -3,37 +3,26 @@ distribution_spec:
description: Use NVIDIA NIM for running LLM inference, evaluation and safety description: Use NVIDIA NIM for running LLM inference, evaluation and safety
providers: providers:
inference: inference:
- provider_id: nvidia - provider_type: remote::nvidia
provider_type: remote::nvidia
vector_io: vector_io:
- provider_id: faiss - provider_type: inline::faiss
provider_type: inline::faiss
safety: safety:
- provider_id: nvidia - provider_type: remote::nvidia
provider_type: remote::nvidia
agents: agents:
- provider_id: meta-reference - provider_type: inline::meta-reference
provider_type: inline::meta-reference
telemetry: telemetry:
- provider_id: meta-reference - provider_type: inline::meta-reference
provider_type: inline::meta-reference
eval: eval:
- provider_id: nvidia - provider_type: remote::nvidia
provider_type: remote::nvidia
post_training: post_training:
- provider_id: nvidia - provider_type: remote::nvidia
provider_type: remote::nvidia
datasetio: datasetio:
- provider_id: localfs - provider_type: inline::localfs
provider_type: inline::localfs - provider_type: remote::nvidia
- provider_id: nvidia
provider_type: remote::nvidia
scoring: scoring:
- provider_id: basic - provider_type: inline::basic
provider_type: inline::basic
tool_runtime: tool_runtime:
- provider_id: rag-runtime - provider_type: inline::rag-runtime
provider_type: inline::rag-runtime
image_type: conda image_type: conda
image_name: nvidia image_name: nvidia
additional_pip_packages: additional_pip_packages:

View file

@ -6,7 +6,7 @@
from pathlib import Path from pathlib import Path
from llama_stack.distribution.datatypes import ModelInput, Provider, ShieldInput, ToolGroupInput from llama_stack.distribution.datatypes import BuildProvider, ModelInput, Provider, ShieldInput, ToolGroupInput
from llama_stack.providers.remote.datasetio.nvidia import NvidiaDatasetIOConfig from llama_stack.providers.remote.datasetio.nvidia import NvidiaDatasetIOConfig
from llama_stack.providers.remote.eval.nvidia import NVIDIAEvalConfig from llama_stack.providers.remote.eval.nvidia import NVIDIAEvalConfig
from llama_stack.providers.remote.inference.nvidia import NVIDIAConfig from llama_stack.providers.remote.inference.nvidia import NVIDIAConfig
@ -17,65 +17,19 @@ from llama_stack.templates.template import DistributionTemplate, RunConfigSettin
def get_distribution_template() -> DistributionTemplate: def get_distribution_template() -> DistributionTemplate:
providers = { providers = {
"inference": [ "inference": [BuildProvider(provider_type="remote::nvidia")],
Provider( "vector_io": [BuildProvider(provider_type="inline::faiss")],
provider_id="nvidia", "safety": [BuildProvider(provider_type="remote::nvidia")],
provider_type="remote::nvidia", "agents": [BuildProvider(provider_type="inline::meta-reference")],
) "telemetry": [BuildProvider(provider_type="inline::meta-reference")],
], "eval": [BuildProvider(provider_type="remote::nvidia")],
"vector_io": [ "post_training": [BuildProvider(provider_type="remote::nvidia")],
Provider(
provider_id="faiss",
provider_type="inline::faiss",
)
],
"safety": [
Provider(
provider_id="nvidia",
provider_type="remote::nvidia",
)
],
"agents": [
Provider(
provider_id="meta-reference",
provider_type="inline::meta-reference",
)
],
"telemetry": [
Provider(
provider_id="meta-reference",
provider_type="inline::meta-reference",
)
],
"eval": [
Provider(
provider_id="nvidia",
provider_type="remote::nvidia",
)
],
"post_training": [Provider(provider_id="nvidia", provider_type="remote::nvidia", config={})],
"datasetio": [ "datasetio": [
Provider( BuildProvider(provider_type="inline::localfs"),
provider_id="localfs", BuildProvider(provider_type="remote::nvidia"),
provider_type="inline::localfs",
),
Provider(
provider_id="nvidia",
provider_type="remote::nvidia",
),
],
"scoring": [
Provider(
provider_id="basic",
provider_type="inline::basic",
)
],
"tool_runtime": [
Provider(
provider_id="rag-runtime",
provider_type="inline::rag-runtime",
)
], ],
"scoring": [BuildProvider(provider_type="inline::basic")],
"tool_runtime": [BuildProvider(provider_type="inline::rag-runtime")],
} }
inference_provider = Provider( inference_provider = Provider(

View file

@ -89,101 +89,51 @@ models:
provider_id: nvidia provider_id: nvidia
provider_model_id: meta/llama3-8b-instruct provider_model_id: meta/llama3-8b-instruct
model_type: llm model_type: llm
- metadata: {}
model_id: meta-llama/Llama-3-8B-Instruct
provider_id: nvidia
provider_model_id: meta/llama3-8b-instruct
model_type: llm
- metadata: {} - metadata: {}
model_id: meta/llama3-70b-instruct model_id: meta/llama3-70b-instruct
provider_id: nvidia provider_id: nvidia
provider_model_id: meta/llama3-70b-instruct provider_model_id: meta/llama3-70b-instruct
model_type: llm model_type: llm
- metadata: {}
model_id: meta-llama/Llama-3-70B-Instruct
provider_id: nvidia
provider_model_id: meta/llama3-70b-instruct
model_type: llm
- metadata: {} - metadata: {}
model_id: meta/llama-3.1-8b-instruct model_id: meta/llama-3.1-8b-instruct
provider_id: nvidia provider_id: nvidia
provider_model_id: meta/llama-3.1-8b-instruct provider_model_id: meta/llama-3.1-8b-instruct
model_type: llm model_type: llm
- metadata: {}
model_id: meta-llama/Llama-3.1-8B-Instruct
provider_id: nvidia
provider_model_id: meta/llama-3.1-8b-instruct
model_type: llm
- metadata: {} - metadata: {}
model_id: meta/llama-3.1-70b-instruct model_id: meta/llama-3.1-70b-instruct
provider_id: nvidia provider_id: nvidia
provider_model_id: meta/llama-3.1-70b-instruct provider_model_id: meta/llama-3.1-70b-instruct
model_type: llm model_type: llm
- metadata: {}
model_id: meta-llama/Llama-3.1-70B-Instruct
provider_id: nvidia
provider_model_id: meta/llama-3.1-70b-instruct
model_type: llm
- metadata: {} - metadata: {}
model_id: meta/llama-3.1-405b-instruct model_id: meta/llama-3.1-405b-instruct
provider_id: nvidia provider_id: nvidia
provider_model_id: meta/llama-3.1-405b-instruct provider_model_id: meta/llama-3.1-405b-instruct
model_type: llm model_type: llm
- metadata: {}
model_id: meta-llama/Llama-3.1-405B-Instruct-FP8
provider_id: nvidia
provider_model_id: meta/llama-3.1-405b-instruct
model_type: llm
- metadata: {} - metadata: {}
model_id: meta/llama-3.2-1b-instruct model_id: meta/llama-3.2-1b-instruct
provider_id: nvidia provider_id: nvidia
provider_model_id: meta/llama-3.2-1b-instruct provider_model_id: meta/llama-3.2-1b-instruct
model_type: llm model_type: llm
- metadata: {}
model_id: meta-llama/Llama-3.2-1B-Instruct
provider_id: nvidia
provider_model_id: meta/llama-3.2-1b-instruct
model_type: llm
- metadata: {} - metadata: {}
model_id: meta/llama-3.2-3b-instruct model_id: meta/llama-3.2-3b-instruct
provider_id: nvidia provider_id: nvidia
provider_model_id: meta/llama-3.2-3b-instruct provider_model_id: meta/llama-3.2-3b-instruct
model_type: llm model_type: llm
- metadata: {}
model_id: meta-llama/Llama-3.2-3B-Instruct
provider_id: nvidia
provider_model_id: meta/llama-3.2-3b-instruct
model_type: llm
- metadata: {} - metadata: {}
model_id: meta/llama-3.2-11b-vision-instruct model_id: meta/llama-3.2-11b-vision-instruct
provider_id: nvidia provider_id: nvidia
provider_model_id: meta/llama-3.2-11b-vision-instruct provider_model_id: meta/llama-3.2-11b-vision-instruct
model_type: llm model_type: llm
- metadata: {}
model_id: meta-llama/Llama-3.2-11B-Vision-Instruct
provider_id: nvidia
provider_model_id: meta/llama-3.2-11b-vision-instruct
model_type: llm
- metadata: {} - metadata: {}
model_id: meta/llama-3.2-90b-vision-instruct model_id: meta/llama-3.2-90b-vision-instruct
provider_id: nvidia provider_id: nvidia
provider_model_id: meta/llama-3.2-90b-vision-instruct provider_model_id: meta/llama-3.2-90b-vision-instruct
model_type: llm model_type: llm
- metadata: {}
model_id: meta-llama/Llama-3.2-90B-Vision-Instruct
provider_id: nvidia
provider_model_id: meta/llama-3.2-90b-vision-instruct
model_type: llm
- metadata: {} - metadata: {}
model_id: meta/llama-3.3-70b-instruct model_id: meta/llama-3.3-70b-instruct
provider_id: nvidia provider_id: nvidia
provider_model_id: meta/llama-3.3-70b-instruct provider_model_id: meta/llama-3.3-70b-instruct
model_type: llm model_type: llm
- metadata: {}
model_id: meta-llama/Llama-3.3-70B-Instruct
provider_id: nvidia
provider_model_id: meta/llama-3.3-70b-instruct
model_type: llm
- metadata: - metadata:
embedding_dimension: 2048 embedding_dimension: 2048
context_length: 8192 context_length: 8192

View file

@ -3,56 +3,35 @@ distribution_spec:
description: Distribution for running open benchmarks description: Distribution for running open benchmarks
providers: providers:
inference: inference:
- provider_id: openai - provider_type: remote::openai
provider_type: remote::openai - provider_type: remote::anthropic
- provider_id: anthropic - provider_type: remote::gemini
provider_type: remote::anthropic - provider_type: remote::groq
- provider_id: gemini - provider_type: remote::together
provider_type: remote::gemini
- provider_id: groq
provider_type: remote::groq
- provider_id: together
provider_type: remote::together
vector_io: vector_io:
- provider_id: sqlite-vec - provider_type: inline::sqlite-vec
provider_type: inline::sqlite-vec - provider_type: remote::chromadb
- provider_id: chromadb - provider_type: remote::pgvector
provider_type: remote::chromadb
- provider_id: pgvector
provider_type: remote::pgvector
safety: safety:
- provider_id: llama-guard - provider_type: inline::llama-guard
provider_type: inline::llama-guard
agents: agents:
- provider_id: meta-reference - provider_type: inline::meta-reference
provider_type: inline::meta-reference
telemetry: telemetry:
- provider_id: meta-reference - provider_type: inline::meta-reference
provider_type: inline::meta-reference
eval: eval:
- provider_id: meta-reference - provider_type: inline::meta-reference
provider_type: inline::meta-reference
datasetio: datasetio:
- provider_id: huggingface - provider_type: remote::huggingface
provider_type: remote::huggingface - provider_type: inline::localfs
- provider_id: localfs
provider_type: inline::localfs
scoring: scoring:
- provider_id: basic - provider_type: inline::basic
provider_type: inline::basic - provider_type: inline::llm-as-judge
- provider_id: llm-as-judge - provider_type: inline::braintrust
provider_type: inline::llm-as-judge
- provider_id: braintrust
provider_type: inline::braintrust
tool_runtime: tool_runtime:
- provider_id: brave-search - provider_type: remote::brave-search
provider_type: remote::brave-search - provider_type: remote::tavily-search
- provider_id: tavily-search - provider_type: inline::rag-runtime
provider_type: remote::tavily-search - provider_type: remote::model-context-protocol
- provider_id: rag-runtime
provider_type: inline::rag-runtime
- provider_id: model-context-protocol
provider_type: remote::model-context-protocol
image_type: conda image_type: conda
image_name: open-benchmark image_name: open-benchmark
additional_pip_packages: additional_pip_packages:

View file

@ -9,6 +9,7 @@ from llama_stack.apis.datasets import DatasetPurpose, URIDataSource
from llama_stack.apis.models import ModelType from llama_stack.apis.models import ModelType
from llama_stack.distribution.datatypes import ( from llama_stack.distribution.datatypes import (
BenchmarkInput, BenchmarkInput,
BuildProvider,
DatasetInput, DatasetInput,
ModelInput, ModelInput,
Provider, Provider,
@ -96,33 +97,30 @@ def get_inference_providers() -> tuple[list[Provider], dict[str, list[ProviderMo
def get_distribution_template() -> DistributionTemplate: def get_distribution_template() -> DistributionTemplate:
inference_providers, available_models = get_inference_providers() inference_providers, available_models = get_inference_providers()
providers = { providers = {
"inference": inference_providers, "inference": [BuildProvider(provider_type=p.provider_type, module=p.module) for p in inference_providers],
"vector_io": [ "vector_io": [
Provider(provider_id="sqlite-vec", provider_type="inline::sqlite-vec"), BuildProvider(provider_type="inline::sqlite-vec"),
Provider(provider_id="chromadb", provider_type="remote::chromadb"), BuildProvider(provider_type="remote::chromadb"),
Provider(provider_id="pgvector", provider_type="remote::pgvector"), BuildProvider(provider_type="remote::pgvector"),
], ],
"safety": [Provider(provider_id="llama-guard", provider_type="inline::llama-guard")], "safety": [BuildProvider(provider_type="inline::llama-guard")],
"agents": [Provider(provider_id="meta-reference", provider_type="inline::meta-reference")], "agents": [BuildProvider(provider_type="inline::meta-reference")],
"telemetry": [Provider(provider_id="meta-reference", provider_type="inline::meta-reference")], "telemetry": [BuildProvider(provider_type="inline::meta-reference")],
"eval": [Provider(provider_id="meta-reference", provider_type="inline::meta-reference")], "eval": [BuildProvider(provider_type="inline::meta-reference")],
"datasetio": [ "datasetio": [
Provider(provider_id="huggingface", provider_type="remote::huggingface"), BuildProvider(provider_type="remote::huggingface"),
Provider(provider_id="localfs", provider_type="inline::localfs"), BuildProvider(provider_type="inline::localfs"),
], ],
"scoring": [ "scoring": [
Provider(provider_id="basic", provider_type="inline::basic"), BuildProvider(provider_type="inline::basic"),
Provider(provider_id="llm-as-judge", provider_type="inline::llm-as-judge"), BuildProvider(provider_type="inline::llm-as-judge"),
Provider(provider_id="braintrust", provider_type="inline::braintrust"), BuildProvider(provider_type="inline::braintrust"),
], ],
"tool_runtime": [ "tool_runtime": [
Provider(provider_id="brave-search", provider_type="remote::brave-search"), BuildProvider(provider_type="remote::brave-search"),
Provider(provider_id="tavily-search", provider_type="remote::tavily-search"), BuildProvider(provider_type="remote::tavily-search"),
Provider(provider_id="rag-runtime", provider_type="inline::rag-runtime"), BuildProvider(provider_type="inline::rag-runtime"),
Provider( BuildProvider(provider_type="remote::model-context-protocol"),
provider_id="model-context-protocol",
provider_type="remote::model-context-protocol",
),
], ],
} }
name = "open-benchmark" name = "open-benchmark"

View file

@ -16,6 +16,7 @@ providers:
provider_type: remote::openai provider_type: remote::openai
config: config:
api_key: ${env.OPENAI_API_KEY:=} api_key: ${env.OPENAI_API_KEY:=}
base_url: ${env.OPENAI_BASE_URL:=https://api.openai.com/v1}
- provider_id: anthropic - provider_id: anthropic
provider_type: remote::anthropic provider_type: remote::anthropic
config: config:
@ -33,7 +34,7 @@ providers:
provider_type: remote::together provider_type: remote::together
config: config:
url: https://api.together.xyz/v1 url: https://api.together.xyz/v1
api_key: ${env.TOGETHER_API_KEY} api_key: ${env.TOGETHER_API_KEY:=}
vector_io: vector_io:
- provider_id: sqlite-vec - provider_id: sqlite-vec
provider_type: inline::sqlite-vec provider_type: inline::sqlite-vec

View file

@ -3,31 +3,21 @@ distribution_spec:
description: Quick start template for running Llama Stack with several popular providers description: Quick start template for running Llama Stack with several popular providers
providers: providers:
inference: inference:
- provider_id: vllm-inference - provider_type: remote::vllm
provider_type: remote::vllm - provider_type: inline::sentence-transformers
- provider_id: sentence-transformers
provider_type: inline::sentence-transformers
vector_io: vector_io:
- provider_id: chromadb - provider_type: remote::chromadb
provider_type: remote::chromadb
safety: safety:
- provider_id: llama-guard - provider_type: inline::llama-guard
provider_type: inline::llama-guard
agents: agents:
- provider_id: meta-reference - provider_type: inline::meta-reference
provider_type: inline::meta-reference
telemetry: telemetry:
- provider_id: meta-reference - provider_type: inline::meta-reference
provider_type: inline::meta-reference
tool_runtime: tool_runtime:
- provider_id: brave-search - provider_type: remote::brave-search
provider_type: remote::brave-search - provider_type: remote::tavily-search
- provider_id: tavily-search - provider_type: inline::rag-runtime
provider_type: remote::tavily-search - provider_type: remote::model-context-protocol
- provider_id: rag-runtime
provider_type: inline::rag-runtime
- provider_id: model-context-protocol
provider_type: remote::model-context-protocol
image_type: conda image_type: conda
image_name: postgres-demo image_name: postgres-demo
additional_pip_packages: additional_pip_packages:

View file

@ -7,6 +7,7 @@
from llama_stack.apis.models import ModelType from llama_stack.apis.models import ModelType
from llama_stack.distribution.datatypes import ( from llama_stack.distribution.datatypes import (
BuildProvider,
ModelInput, ModelInput,
Provider, Provider,
ShieldInput, ShieldInput,
@ -34,24 +35,19 @@ def get_distribution_template() -> DistributionTemplate:
), ),
] ]
providers = { providers = {
"inference": inference_providers "inference": [
+ [ BuildProvider(provider_type="remote::vllm"),
Provider(provider_id="sentence-transformers", provider_type="inline::sentence-transformers"), BuildProvider(provider_type="inline::sentence-transformers"),
], ],
"vector_io": [ "vector_io": [BuildProvider(provider_type="remote::chromadb")],
Provider(provider_id="chromadb", provider_type="remote::chromadb"), "safety": [BuildProvider(provider_type="inline::llama-guard")],
], "agents": [BuildProvider(provider_type="inline::meta-reference")],
"safety": [Provider(provider_id="llama-guard", provider_type="inline::llama-guard")], "telemetry": [BuildProvider(provider_type="inline::meta-reference")],
"agents": [Provider(provider_id="meta-reference", provider_type="inline::meta-reference")],
"telemetry": [Provider(provider_id="meta-reference", provider_type="inline::meta-reference")],
"tool_runtime": [ "tool_runtime": [
Provider(provider_id="brave-search", provider_type="remote::brave-search"), BuildProvider(provider_type="remote::brave-search"),
Provider(provider_id="tavily-search", provider_type="remote::tavily-search"), BuildProvider(provider_type="remote::tavily-search"),
Provider(provider_id="rag-runtime", provider_type="inline::rag-runtime"), BuildProvider(provider_type="inline::rag-runtime"),
Provider( BuildProvider(provider_type="remote::model-context-protocol"),
provider_id="model-context-protocol",
provider_type="remote::model-context-protocol",
),
], ],
} }
name = "postgres-demo" name = "postgres-demo"

View file

@ -3,96 +3,50 @@ distribution_spec:
description: Quick start template for running Llama Stack with several popular providers description: Quick start template for running Llama Stack with several popular providers
providers: providers:
inference: inference:
- provider_id: ${env.ENABLE_CEREBRAS:=__disabled__} - provider_type: remote::cerebras
provider_type: remote::cerebras - provider_type: remote::ollama
- provider_id: ${env.ENABLE_OLLAMA:=__disabled__} - provider_type: remote::vllm
provider_type: remote::ollama - provider_type: remote::tgi
- provider_id: ${env.ENABLE_VLLM:=__disabled__} - provider_type: remote::fireworks
provider_type: remote::vllm - provider_type: remote::together
- provider_id: ${env.ENABLE_TGI:=__disabled__} - provider_type: remote::bedrock
provider_type: remote::tgi - provider_type: remote::nvidia
- provider_id: ${env.ENABLE_HF_SERVERLESS:=__disabled__} - provider_type: remote::openai
provider_type: remote::hf::serverless - provider_type: remote::anthropic
- provider_id: ${env.ENABLE_HF_ENDPOINT:=__disabled__} - provider_type: remote::gemini
provider_type: remote::hf::endpoint - provider_type: remote::groq
- provider_id: ${env.ENABLE_FIREWORKS:=__disabled__} - provider_type: remote::sambanova
provider_type: remote::fireworks - provider_type: inline::sentence-transformers
- provider_id: ${env.ENABLE_TOGETHER:=__disabled__}
provider_type: remote::together
- provider_id: ${env.ENABLE_BEDROCK:=__disabled__}
provider_type: remote::bedrock
- provider_id: ${env.ENABLE_DATABRICKS:=__disabled__}
provider_type: remote::databricks
- provider_id: ${env.ENABLE_NVIDIA:=__disabled__}
provider_type: remote::nvidia
- provider_id: ${env.ENABLE_RUNPOD:=__disabled__}
provider_type: remote::runpod
- provider_id: ${env.ENABLE_OPENAI:=__disabled__}
provider_type: remote::openai
- provider_id: ${env.ENABLE_ANTHROPIC:=__disabled__}
provider_type: remote::anthropic
- provider_id: ${env.ENABLE_GEMINI:=__disabled__}
provider_type: remote::gemini
- provider_id: ${env.ENABLE_GROQ:=__disabled__}
provider_type: remote::groq
- provider_id: ${env.ENABLE_LLAMA_OPENAI_COMPAT:=__disabled__}
provider_type: remote::llama-openai-compat
- provider_id: ${env.ENABLE_SAMBANOVA:=__disabled__}
provider_type: remote::sambanova
- provider_id: ${env.ENABLE_PASSTHROUGH:=__disabled__}
provider_type: remote::passthrough
- provider_id: sentence-transformers
provider_type: inline::sentence-transformers
vector_io: vector_io:
- provider_id: ${env.ENABLE_FAISS:=faiss} - provider_type: inline::faiss
provider_type: inline::faiss - provider_type: inline::sqlite-vec
- provider_id: ${env.ENABLE_SQLITE_VEC:=__disabled__} - provider_type: inline::milvus
provider_type: inline::sqlite-vec - provider_type: remote::chromadb
- provider_id: ${env.ENABLE_MILVUS:=__disabled__} - provider_type: remote::pgvector
provider_type: inline::milvus
- provider_id: ${env.ENABLE_CHROMADB:=__disabled__}
provider_type: remote::chromadb
- provider_id: ${env.ENABLE_PGVECTOR:=__disabled__}
provider_type: remote::pgvector
files: files:
- provider_id: localfs - provider_type: inline::localfs
provider_type: inline::localfs
safety: safety:
- provider_id: llama-guard - provider_type: inline::llama-guard
provider_type: inline::llama-guard
agents: agents:
- provider_id: meta-reference - provider_type: inline::meta-reference
provider_type: inline::meta-reference
telemetry: telemetry:
- provider_id: meta-reference - provider_type: inline::meta-reference
provider_type: inline::meta-reference
post_training: post_training:
- provider_id: huggingface - provider_type: inline::huggingface
provider_type: inline::huggingface
eval: eval:
- provider_id: meta-reference - provider_type: inline::meta-reference
provider_type: inline::meta-reference
datasetio: datasetio:
- provider_id: huggingface - provider_type: remote::huggingface
provider_type: remote::huggingface - provider_type: inline::localfs
- provider_id: localfs
provider_type: inline::localfs
scoring: scoring:
- provider_id: basic - provider_type: inline::basic
provider_type: inline::basic - provider_type: inline::llm-as-judge
- provider_id: llm-as-judge - provider_type: inline::braintrust
provider_type: inline::llm-as-judge
- provider_id: braintrust
provider_type: inline::braintrust
tool_runtime: tool_runtime:
- provider_id: brave-search - provider_type: remote::brave-search
provider_type: remote::brave-search - provider_type: remote::tavily-search
- provider_id: tavily-search - provider_type: inline::rag-runtime
provider_type: remote::tavily-search - provider_type: remote::model-context-protocol
- provider_id: rag-runtime
provider_type: inline::rag-runtime
- provider_id: model-context-protocol
provider_type: remote::model-context-protocol
image_type: conda image_type: conda
image_name: starter image_name: starter
additional_pip_packages: additional_pip_packages:

File diff suppressed because it is too large

View file

@ -7,19 +7,19 @@
from typing import Any from typing import Any
from llama_stack.apis.models import ModelType
from llama_stack.distribution.datatypes import ( from llama_stack.distribution.datatypes import (
ModelInput, BuildProvider,
Provider, Provider,
ProviderSpec, ProviderSpec,
ShieldInput,
ToolGroupInput, ToolGroupInput,
) )
from llama_stack.distribution.utils.dynamic import instantiate_class_type from llama_stack.distribution.utils.dynamic import instantiate_class_type
from llama_stack.providers.datatypes import RemoteProviderSpec
from llama_stack.providers.inline.files.localfs.config import LocalfsFilesImplConfig from llama_stack.providers.inline.files.localfs.config import LocalfsFilesImplConfig
from llama_stack.providers.inline.inference.sentence_transformers import ( from llama_stack.providers.inline.inference.sentence_transformers import (
SentenceTransformersInferenceConfig, SentenceTransformersInferenceConfig,
) )
from llama_stack.providers.inline.post_training.huggingface import HuggingFacePostTrainingConfig
from llama_stack.providers.inline.vector_io.faiss.config import FaissVectorIOConfig from llama_stack.providers.inline.vector_io.faiss.config import FaissVectorIOConfig
from llama_stack.providers.inline.vector_io.milvus.config import ( from llama_stack.providers.inline.vector_io.milvus.config import (
MilvusVectorIOConfig, MilvusVectorIOConfig,
@ -28,117 +28,17 @@ from llama_stack.providers.inline.vector_io.sqlite_vec.config import (
SQLiteVectorIOConfig, SQLiteVectorIOConfig,
) )
from llama_stack.providers.registry.inference import available_providers from llama_stack.providers.registry.inference import available_providers
from llama_stack.providers.remote.inference.anthropic.models import (
MODEL_ENTRIES as ANTHROPIC_MODEL_ENTRIES,
)
from llama_stack.providers.remote.inference.bedrock.models import (
MODEL_ENTRIES as BEDROCK_MODEL_ENTRIES,
)
from llama_stack.providers.remote.inference.cerebras.models import (
MODEL_ENTRIES as CEREBRAS_MODEL_ENTRIES,
)
from llama_stack.providers.remote.inference.databricks.databricks import (
MODEL_ENTRIES as DATABRICKS_MODEL_ENTRIES,
)
from llama_stack.providers.remote.inference.fireworks.models import (
MODEL_ENTRIES as FIREWORKS_MODEL_ENTRIES,
)
from llama_stack.providers.remote.inference.gemini.models import (
MODEL_ENTRIES as GEMINI_MODEL_ENTRIES,
)
from llama_stack.providers.remote.inference.groq.models import (
MODEL_ENTRIES as GROQ_MODEL_ENTRIES,
)
from llama_stack.providers.remote.inference.nvidia.models import (
MODEL_ENTRIES as NVIDIA_MODEL_ENTRIES,
)
from llama_stack.providers.remote.inference.openai.models import (
MODEL_ENTRIES as OPENAI_MODEL_ENTRIES,
)
from llama_stack.providers.remote.inference.runpod.runpod import (
MODEL_ENTRIES as RUNPOD_MODEL_ENTRIES,
)
from llama_stack.providers.remote.inference.sambanova.models import (
MODEL_ENTRIES as SAMBANOVA_MODEL_ENTRIES,
)
from llama_stack.providers.remote.inference.together.models import (
MODEL_ENTRIES as TOGETHER_MODEL_ENTRIES,
)
from llama_stack.providers.remote.vector_io.chroma.config import ChromaVectorIOConfig from llama_stack.providers.remote.vector_io.chroma.config import ChromaVectorIOConfig
from llama_stack.providers.remote.vector_io.pgvector.config import ( from llama_stack.providers.remote.vector_io.pgvector.config import (
PGVectorVectorIOConfig, PGVectorVectorIOConfig,
) )
from llama_stack.providers.utils.inference.model_registry import ProviderModelEntry
from llama_stack.providers.utils.sqlstore.sqlstore import PostgresSqlStoreConfig from llama_stack.providers.utils.sqlstore.sqlstore import PostgresSqlStoreConfig
from llama_stack.templates.template import ( from llama_stack.templates.template import (
DistributionTemplate, DistributionTemplate,
RunConfigSettings, RunConfigSettings,
get_model_registry,
get_shield_registry,
) )
def _get_model_entries_for_provider(provider_type: str) -> list[ProviderModelEntry]:
"""Get model entries for a specific provider type."""
model_entries_map = {
"openai": OPENAI_MODEL_ENTRIES,
"fireworks": FIREWORKS_MODEL_ENTRIES,
"together": TOGETHER_MODEL_ENTRIES,
"anthropic": ANTHROPIC_MODEL_ENTRIES,
"gemini": GEMINI_MODEL_ENTRIES,
"groq": GROQ_MODEL_ENTRIES,
"sambanova": SAMBANOVA_MODEL_ENTRIES,
"cerebras": CEREBRAS_MODEL_ENTRIES,
"bedrock": BEDROCK_MODEL_ENTRIES,
"databricks": DATABRICKS_MODEL_ENTRIES,
"nvidia": NVIDIA_MODEL_ENTRIES,
"runpod": RUNPOD_MODEL_ENTRIES,
}
# Special handling for providers with dynamic model entries
if provider_type == "ollama":
return [
ProviderModelEntry(
provider_model_id="${env.OLLAMA_INFERENCE_MODEL:=__disabled__}",
model_type=ModelType.llm,
),
ProviderModelEntry(
provider_model_id="${env.SAFETY_MODEL:=__disabled__}",
model_type=ModelType.llm,
),
ProviderModelEntry(
provider_model_id="${env.OLLAMA_EMBEDDING_MODEL:=__disabled__}",
model_type=ModelType.embedding,
metadata={
"embedding_dimension": "${env.OLLAMA_EMBEDDING_DIMENSION:=384}",
},
),
]
elif provider_type == "vllm":
return [
ProviderModelEntry(
provider_model_id="${env.VLLM_INFERENCE_MODEL:=__disabled__}",
model_type=ModelType.llm,
),
]
return model_entries_map.get(provider_type, [])
def _get_model_safety_entries_for_provider(provider_type: str) -> list[ProviderModelEntry]:
"""Get model entries for a specific provider type."""
safety_model_entries_map = {
"ollama": [
ProviderModelEntry(
provider_model_id="${env.SAFETY_MODEL:=__disabled__}",
model_type=ModelType.llm,
),
],
}
return safety_model_entries_map.get(provider_type, [])
def _get_config_for_provider(provider_spec: ProviderSpec) -> dict[str, Any]: def _get_config_for_provider(provider_spec: ProviderSpec) -> dict[str, Any]:
"""Get configuration for a provider using its adapter's config class.""" """Get configuration for a provider using its adapter's config class."""
config_class = instantiate_class_type(provider_spec.config_class) config_class = instantiate_class_type(provider_spec.config_class)
@ -149,40 +49,48 @@ def _get_config_for_provider(provider_spec: ProviderSpec) -> dict[str, Any]:
return {} return {}
def get_remote_inference_providers() -> tuple[list[Provider], dict[str, list[ProviderModelEntry]]]: ENABLED_INFERENCE_PROVIDERS = [
all_providers = available_providers() "ollama",
"vllm",
# Filter out inline providers and watsonx - the starter distro only exposes remote providers "tgi",
remote_providers = [ "fireworks",
provider "together",
for provider in all_providers "gemini",
# TODO: re-add once the Python 3.13 issue is fixed "groq",
# discussion: https://github.com/meta-llama/llama-stack/pull/2327#discussion_r2156883828 "sambanova",
if hasattr(provider, "adapter") and provider.adapter.adapter_type != "watsonx" "anthropic",
"openai",
"cerebras",
"nvidia",
"bedrock",
] ]
providers = [] INFERENCE_PROVIDER_IDS = {
available_models = {} "vllm": "${env.VLLM_URL:+vllm}",
"tgi": "${env.TGI_URL:+tgi}",
"cerebras": "${env.CEREBRAS_API_KEY:+cerebras}",
"nvidia": "${env.NVIDIA_API_KEY:+nvidia}",
}
def get_remote_inference_providers() -> list[Provider]:
# Filter out inline providers and some others - the starter distro only exposes remote providers
remote_providers = [
provider
for provider in available_providers()
if isinstance(provider, RemoteProviderSpec) and provider.adapter.adapter_type in ENABLED_INFERENCE_PROVIDERS
]
inference_providers = []
for provider_spec in remote_providers: for provider_spec in remote_providers:
provider_type = provider_spec.adapter.adapter_type provider_type = provider_spec.adapter.adapter_type
# Build the environment variable name for enabling this provider if provider_type in INFERENCE_PROVIDER_IDS:
env_var = f"ENABLE_{provider_type.upper().replace('-', '_').replace('::', '_')}" provider_id = INFERENCE_PROVIDER_IDS[provider_type]
model_entries = _get_model_entries_for_provider(provider_type) else:
provider_id = provider_type.replace("-", "_").replace("::", "_")
config = _get_config_for_provider(provider_spec) config = _get_config_for_provider(provider_spec)
providers.append(
(
f"${{env.{env_var}:=__disabled__}}",
provider_type,
model_entries,
config,
)
)
available_models[f"${{env.{env_var}:=__disabled__}}"] = model_entries
inference_providers = []
for provider_id, provider_type, model_entries, config in providers:
inference_providers.append( inference_providers.append(
Provider( Provider(
provider_id=provider_id, provider_id=provider_id,
@ -190,154 +98,43 @@ def get_remote_inference_providers() -> tuple[list[Provider], dict[str, list[Pro
config=config, config=config,
) )
) )
available_models[provider_id] = model_entries return inference_providers
return inference_providers, available_models
# build a list of shields for all possible providers
def get_safety_models_for_providers(providers: list[Provider]) -> dict[str, list[ProviderModelEntry]]:
available_models = {}
for provider in providers:
provider_type = provider.provider_type.split("::")[1]
safety_model_entries = _get_model_safety_entries_for_provider(provider_type)
if len(safety_model_entries) == 0:
continue
env_var = f"ENABLE_{provider_type.upper().replace('-', '_').replace('::', '_')}"
provider_id = f"${{env.{env_var}:=__disabled__}}"
available_models[provider_id] = safety_model_entries
return available_models
def get_distribution_template() -> DistributionTemplate: def get_distribution_template() -> DistributionTemplate:
remote_inference_providers, available_models = get_remote_inference_providers() remote_inference_providers = get_remote_inference_providers()
name = "starter" name = "starter"
vector_io_providers = [
Provider(
provider_id="${env.ENABLE_FAISS:=faiss}",
provider_type="inline::faiss",
config=FaissVectorIOConfig.sample_run_config(f"~/.llama/distributions/{name}"),
),
Provider(
provider_id="${env.ENABLE_SQLITE_VEC:=__disabled__}",
provider_type="inline::sqlite-vec",
config=SQLiteVectorIOConfig.sample_run_config(f"~/.llama/distributions/{name}"),
),
Provider(
provider_id="${env.ENABLE_MILVUS:=__disabled__}",
provider_type="inline::milvus",
config=MilvusVectorIOConfig.sample_run_config(f"~/.llama/distributions/{name}"),
),
Provider(
provider_id="${env.ENABLE_CHROMADB:=__disabled__}",
provider_type="remote::chromadb",
config=ChromaVectorIOConfig.sample_run_config(
f"~/.llama/distributions/{name}/",
url="${env.CHROMADB_URL:=}",
),
),
Provider(
provider_id="${env.ENABLE_PGVECTOR:=__disabled__}",
provider_type="remote::pgvector",
config=PGVectorVectorIOConfig.sample_run_config(
f"~/.llama/distributions/{name}",
db="${env.PGVECTOR_DB:=}",
user="${env.PGVECTOR_USER:=}",
password="${env.PGVECTOR_PASSWORD:=}",
),
),
]
providers = { providers = {
"inference": remote_inference_providers "inference": [BuildProvider(provider_type=p.provider_type, module=p.module) for p in remote_inference_providers]
+ [ + [BuildProvider(provider_type="inline::sentence-transformers")],
Provider( "vector_io": [
provider_id="sentence-transformers", BuildProvider(provider_type="inline::faiss"),
provider_type="inline::sentence-transformers", BuildProvider(provider_type="inline::sqlite-vec"),
) BuildProvider(provider_type="inline::milvus"),
], BuildProvider(provider_type="remote::chromadb"),
"vector_io": vector_io_providers, BuildProvider(provider_type="remote::pgvector"),
"files": [
Provider(
provider_id="localfs",
provider_type="inline::localfs",
)
],
"safety": [
Provider(
provider_id="llama-guard",
provider_type="inline::llama-guard",
)
],
"agents": [
Provider(
provider_id="meta-reference",
provider_type="inline::meta-reference",
)
],
"telemetry": [
Provider(
provider_id="meta-reference",
provider_type="inline::meta-reference",
)
],
"post_training": [
Provider(
provider_id="huggingface",
provider_type="inline::huggingface",
)
],
"eval": [
Provider(
provider_id="meta-reference",
provider_type="inline::meta-reference",
)
], ],
"files": [BuildProvider(provider_type="inline::localfs")],
"safety": [BuildProvider(provider_type="inline::llama-guard")],
"agents": [BuildProvider(provider_type="inline::meta-reference")],
"telemetry": [BuildProvider(provider_type="inline::meta-reference")],
"post_training": [BuildProvider(provider_type="inline::huggingface")],
"eval": [BuildProvider(provider_type="inline::meta-reference")],
"datasetio": [ "datasetio": [
Provider( BuildProvider(provider_type="remote::huggingface"),
provider_id="huggingface", BuildProvider(provider_type="inline::localfs"),
provider_type="remote::huggingface",
),
Provider(
provider_id="localfs",
provider_type="inline::localfs",
),
], ],
"scoring": [ "scoring": [
Provider( BuildProvider(provider_type="inline::basic"),
provider_id="basic", BuildProvider(provider_type="inline::llm-as-judge"),
provider_type="inline::basic", BuildProvider(provider_type="inline::braintrust"),
),
Provider(
provider_id="llm-as-judge",
provider_type="inline::llm-as-judge",
),
Provider(
provider_id="braintrust",
provider_type="inline::braintrust",
),
], ],
"tool_runtime": [ "tool_runtime": [
Provider( BuildProvider(provider_type="remote::brave-search"),
provider_id="brave-search", BuildProvider(provider_type="remote::tavily-search"),
provider_type="remote::brave-search", BuildProvider(provider_type="inline::rag-runtime"),
), BuildProvider(provider_type="remote::model-context-protocol"),
Provider(
provider_id="tavily-search",
provider_type="remote::tavily-search",
),
Provider(
provider_id="rag-runtime",
provider_type="inline::rag-runtime",
),
Provider(
provider_id="model-context-protocol",
provider_type="remote::model-context-protocol",
),
], ],
} }
files_provider = Provider( files_provider = Provider(
@ -346,15 +143,10 @@ def get_distribution_template() -> DistributionTemplate:
config=LocalfsFilesImplConfig.sample_run_config(f"~/.llama/distributions/{name}"), config=LocalfsFilesImplConfig.sample_run_config(f"~/.llama/distributions/{name}"),
) )
embedding_provider = Provider( embedding_provider = Provider(
provider_id="${env.ENABLE_SENTENCE_TRANSFORMERS:=sentence-transformers}", provider_id="sentence-transformers",
provider_type="inline::sentence-transformers", provider_type="inline::sentence-transformers",
config=SentenceTransformersInferenceConfig.sample_run_config(), config=SentenceTransformersInferenceConfig.sample_run_config(),
) )
post_training_provider = Provider(
provider_id="huggingface",
provider_type="inline::huggingface",
config=HuggingFacePostTrainingConfig.sample_run_config(f"~/.llama/distributions/{name}"),
)
default_tool_groups = [ default_tool_groups = [
ToolGroupInput( ToolGroupInput(
toolgroup_id="builtin::websearch", toolgroup_id="builtin::websearch",
@ -365,19 +157,14 @@ def get_distribution_template() -> DistributionTemplate:
provider_id="rag-runtime", provider_id="rag-runtime",
), ),
] ]
embedding_model = ModelInput( default_shields = [
model_id="all-MiniLM-L6-v2", # if the
provider_id=embedding_provider.provider_id, ShieldInput(
model_type=ModelType.embedding, shield_id="llama-guard",
metadata={ provider_id="${env.SAFETY_MODEL:+llama-guard}",
"embedding_dimension": 384, provider_shield_id="${env.SAFETY_MODEL:=}",
}, ),
) ]
default_models, ids_conflict_in_models = get_model_registry(available_models)
available_safety_models = get_safety_models_for_providers(remote_inference_providers)
shields = get_shield_registry(available_safety_models, ids_conflict_in_models)
return DistributionTemplate( return DistributionTemplate(
name=name, name=name,
@ -386,20 +173,51 @@ def get_distribution_template() -> DistributionTemplate:
container_image=None, container_image=None,
template_path=None, template_path=None,
providers=providers, providers=providers,
available_models_by_provider=available_models,
additional_pip_packages=PostgresSqlStoreConfig.pip_packages(), additional_pip_packages=PostgresSqlStoreConfig.pip_packages(),
run_configs={ run_configs={
"run.yaml": RunConfigSettings( "run.yaml": RunConfigSettings(
provider_overrides={ provider_overrides={
"inference": remote_inference_providers + [embedding_provider], "inference": remote_inference_providers + [embedding_provider],
"vector_io": vector_io_providers, "vector_io": [
Provider(
provider_id="faiss",
provider_type="inline::faiss",
config=FaissVectorIOConfig.sample_run_config(f"~/.llama/distributions/{name}"),
),
Provider(
provider_id="sqlite-vec",
provider_type="inline::sqlite-vec",
config=SQLiteVectorIOConfig.sample_run_config(f"~/.llama/distributions/{name}"),
),
Provider(
provider_id="${env.MILVUS_URL:+milvus}",
provider_type="inline::milvus",
config=MilvusVectorIOConfig.sample_run_config(f"~/.llama/distributions/{name}"),
),
Provider(
provider_id="${env.CHROMADB_URL:+chromadb}",
provider_type="remote::chromadb",
config=ChromaVectorIOConfig.sample_run_config(
f"~/.llama/distributions/{name}/",
url="${env.CHROMADB_URL:=}",
),
),
Provider(
provider_id="${env.PGVECTOR_DB:+pgvector}",
provider_type="remote::pgvector",
config=PGVectorVectorIOConfig.sample_run_config(
f"~/.llama/distributions/{name}",
db="${env.PGVECTOR_DB:=}",
user="${env.PGVECTOR_USER:=}",
password="${env.PGVECTOR_PASSWORD:=}",
),
),
],
"files": [files_provider], "files": [files_provider],
"post_training": [post_training_provider],
}, },
default_models=[embedding_model] + default_models, default_models=[],
default_tool_groups=default_tool_groups, default_tool_groups=default_tool_groups,
# TODO: add a way to enable/disable shields on the fly default_shields=default_shields,
default_shields=shields,
), ),
}, },
run_config_env_vars={ run_config_env_vars={
@ -443,17 +261,5 @@ def get_distribution_template() -> DistributionTemplate:
"http://localhost:11434", "http://localhost:11434",
"Ollama URL", "Ollama URL",
), ),
"OLLAMA_INFERENCE_MODEL": (
"",
"Optional Ollama Inference Model to register on startup",
),
"OLLAMA_EMBEDDING_MODEL": (
"",
"Optional Ollama Embedding Model to register on startup",
),
"OLLAMA_EMBEDDING_DIMENSION": (
"384",
"Ollama Embedding Dimension",
),
}, },
) )

View file

@ -19,6 +19,7 @@ from llama_stack.distribution.datatypes import (
Api, Api,
BenchmarkInput, BenchmarkInput,
BuildConfig, BuildConfig,
BuildProvider,
DatasetInput, DatasetInput,
DistributionSpec, DistributionSpec,
ModelInput, ModelInput,
@ -183,7 +184,7 @@ class RunConfigSettings(BaseModel):
def run_config( def run_config(
self, self,
name: str, name: str,
providers: dict[str, list[Provider]], providers: dict[str, list[BuildProvider]],
container_image: str | None = None, container_image: str | None = None,
) -> dict: ) -> dict:
provider_registry = get_provider_registry() provider_registry = get_provider_registry()
@ -199,7 +200,7 @@ class RunConfigSettings(BaseModel):
api = Api(api_str) api = Api(api_str)
if provider.provider_type not in provider_registry[api]: if provider.provider_type not in provider_registry[api]:
raise ValueError(f"Unknown provider type: {provider.provider_type} for API: {api_str}") raise ValueError(f"Unknown provider type: {provider.provider_type} for API: {api_str}")
provider_id = provider.provider_type.split("::")[-1]
config_class = provider_registry[api][provider.provider_type].config_class config_class = provider_registry[api][provider.provider_type].config_class
assert config_class is not None, ( assert config_class is not None, (
f"No config class for provider type: {provider.provider_type} for API: {api_str}" f"No config class for provider type: {provider.provider_type} for API: {api_str}"
@ -210,10 +211,14 @@ class RunConfigSettings(BaseModel):
config = config_class.sample_run_config(__distro_dir__=f"~/.llama/distributions/{name}") config = config_class.sample_run_config(__distro_dir__=f"~/.llama/distributions/{name}")
else: else:
config = {} config = {}
# BuildProvider does not have a config attribute; skip assignment
provider.config = config provider_configs[api_str].append(
# Convert Provider object to dict for YAML serialization Provider(
provider_configs[api_str].append(provider.model_dump(exclude_none=True)) provider_id=provider_id,
provider_type=provider.provider_type,
config=config,
).model_dump(exclude_none=True)
)
# Get unique set of APIs from providers # Get unique set of APIs from providers
apis = sorted(providers.keys()) apis = sorted(providers.keys())
@ -257,7 +262,8 @@ class DistributionTemplate(BaseModel):
description: str description: str
distro_type: Literal["self_hosted", "remote_hosted", "ondevice"] distro_type: Literal["self_hosted", "remote_hosted", "ondevice"]
providers: dict[str, list[Provider]] # Now uses BuildProvider for build config, not Provider
providers: dict[str, list[BuildProvider]]
run_configs: dict[str, RunConfigSettings] run_configs: dict[str, RunConfigSettings]
template_path: Path | None = None template_path: Path | None = None
@ -295,11 +301,9 @@ class DistributionTemplate(BaseModel):
for api, providers in self.providers.items(): for api, providers in self.providers.items():
build_providers[api] = [] build_providers[api] = []
for provider in providers: for provider in providers:
# Create a minimal provider object with only essential build information # Create a minimal build provider object with only essential build information
build_provider = Provider( build_provider = BuildProvider(
provider_id=provider.provider_id,
provider_type=provider.provider_type, provider_type=provider.provider_type,
config={}, # Empty config for build
module=provider.module, module=provider.module,
) )
build_providers[api].append(build_provider) build_providers[api].append(build_provider)
@ -323,6 +327,7 @@ class DistributionTemplate(BaseModel):
providers_str = ", ".join(f"`{p.provider_type}`" for p in providers) providers_str = ", ".join(f"`{p.provider_type}`" for p in providers)
providers_table += f"| {api} | {providers_str} |\n" providers_table += f"| {api} | {providers_str} |\n"
if self.template_path is not None:
template = self.template_path.read_text() template = self.template_path.read_text()
comment = "<!-- This file was auto-generated by distro_codegen.py, please edit source -->\n" comment = "<!-- This file was auto-generated by distro_codegen.py, please edit source -->\n"
orphantext = "---\norphan: true\n---\n" orphantext = "---\norphan: true\n---\n"
@ -367,6 +372,7 @@ class DistributionTemplate(BaseModel):
run_config_env_vars=self.run_config_env_vars, run_config_env_vars=self.run_config_env_vars,
default_models=default_models, default_models=default_models,
) )
return ""
def save_distribution(self, yaml_output_dir: Path, doc_output_dir: Path) -> None: def save_distribution(self, yaml_output_dir: Path, doc_output_dir: Path) -> None:
def enum_representer(dumper, data): def enum_representer(dumper, data):

View file

@ -7,7 +7,7 @@
from pathlib import Path from pathlib import Path
from llama_stack.apis.models import ModelType from llama_stack.apis.models import ModelType
from llama_stack.distribution.datatypes import ModelInput, Provider, ToolGroupInput from llama_stack.distribution.datatypes import BuildProvider, ModelInput, Provider, ToolGroupInput
from llama_stack.providers.inline.inference.sentence_transformers import ( from llama_stack.providers.inline.inference.sentence_transformers import (
SentenceTransformersInferenceConfig, SentenceTransformersInferenceConfig,
) )
@ -19,86 +19,28 @@ from llama_stack.templates.template import DistributionTemplate, RunConfigSettin
def get_distribution_template() -> DistributionTemplate: def get_distribution_template() -> DistributionTemplate:
providers = { providers = {
"inference": [ "inference": [
Provider( BuildProvider(provider_type="remote::watsonx"),
provider_id="watsonx", BuildProvider(provider_type="inline::sentence-transformers"),
provider_type="remote::watsonx",
),
Provider(
provider_id="sentence-transformers",
provider_type="inline::sentence-transformers",
),
],
"vector_io": [
Provider(
provider_id="faiss",
provider_type="inline::faiss",
)
],
"safety": [
Provider(
provider_id="llama-guard",
provider_type="inline::llama-guard",
)
],
"agents": [
Provider(
provider_id="meta-reference",
provider_type="inline::meta-reference",
)
],
"telemetry": [
Provider(
provider_id="meta-reference",
provider_type="inline::meta-reference",
)
],
"eval": [
Provider(
provider_id="meta-reference",
provider_type="inline::meta-reference",
)
], ],
"vector_io": [BuildProvider(provider_type="inline::faiss")],
"safety": [BuildProvider(provider_type="inline::llama-guard")],
"agents": [BuildProvider(provider_type="inline::meta-reference")],
"telemetry": [BuildProvider(provider_type="inline::meta-reference")],
"eval": [BuildProvider(provider_type="inline::meta-reference")],
"datasetio": [ "datasetio": [
Provider( BuildProvider(provider_type="remote::huggingface"),
provider_id="huggingface", BuildProvider(provider_type="inline::localfs"),
provider_type="remote::huggingface",
),
Provider(
provider_id="localfs",
provider_type="inline::localfs",
),
], ],
"scoring": [ "scoring": [
Provider( BuildProvider(provider_type="inline::basic"),
provider_id="basic", BuildProvider(provider_type="inline::llm-as-judge"),
provider_type="inline::basic", BuildProvider(provider_type="inline::braintrust"),
),
Provider(
provider_id="llm-as-judge",
provider_type="inline::llm-as-judge",
),
Provider(
provider_id="braintrust",
provider_type="inline::braintrust",
),
], ],
"tool_runtime": [ "tool_runtime": [
Provider( BuildProvider(provider_type="remote::brave-search"),
provider_id="brave-search", BuildProvider(provider_type="remote::tavily-search"),
provider_type="remote::brave-search", BuildProvider(provider_type="inline::rag-runtime"),
), BuildProvider(provider_type="remote::model-context-protocol"),
Provider(
provider_id="tavily-search",
provider_type="remote::tavily-search",
),
Provider(
provider_id="rag-runtime",
provider_type="inline::rag-runtime",
),
Provider(
provider_id="model-context-protocol",
provider_type="remote::model-context-protocol",
),
], ],
} }

View file

@ -20,7 +20,7 @@
"@radix-ui/react-tooltip": "^1.2.6", "@radix-ui/react-tooltip": "^1.2.6",
"class-variance-authority": "^0.7.1", "class-variance-authority": "^0.7.1",
"clsx": "^2.1.1", "clsx": "^2.1.1",
"llama-stack-client": "^0.2.15", "llama-stack-client": ""0.2.16",
"lucide-react": "^0.510.0", "lucide-react": "^0.510.0",
"next": "15.3.3", "next": "15.3.3",
"next-auth": "^4.24.11", "next-auth": "^4.24.11",

View file

@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
[project] [project]
name = "llama_stack" name = "llama_stack"
version = "0.2.15" version = "0.2.16"
authors = [{ name = "Meta Llama", email = "llama-oss@meta.com" }] authors = [{ name = "Meta Llama", email = "llama-oss@meta.com" }]
description = "Llama Stack" description = "Llama Stack"
readme = "README.md" readme = "README.md"
@ -25,10 +25,10 @@ dependencies = [
"fastapi>=0.115.0,<1.0", # server "fastapi>=0.115.0,<1.0", # server
"fire", # for MCP in LLS client "fire", # for MCP in LLS client
"httpx", "httpx",
"huggingface-hub>=0.30.0,<1.0", "huggingface-hub>=0.34.0,<1.0",
"jinja2>=3.1.6", "jinja2>=3.1.6",
"jsonschema", "jsonschema",
"llama-stack-client>=0.2.15", "llama-stack-client>=0.2.16",
"llama-api-client>=0.1.2", "llama-api-client>=0.1.2",
"openai>=1.66", "openai>=1.66",
"prompt-toolkit", "prompt-toolkit",
@ -53,7 +53,7 @@ dependencies = [
ui = [ ui = [
"streamlit", "streamlit",
"pandas", "pandas",
"llama-stack-client>=0.2.15", "llama-stack-client>=0.2.16",
"streamlit-option-menu", "streamlit-option-menu",
] ]
@ -114,6 +114,7 @@ test = [
"sqlalchemy[asyncio]>=2.0.41", "sqlalchemy[asyncio]>=2.0.41",
"requests", "requests",
"pymilvus>=2.5.12", "pymilvus>=2.5.12",
"reportlab",
] ]
docs = [ docs = [
"setuptools", "setuptools",

View file

@ -86,7 +86,7 @@ httpx==0.28.1
# llama-stack # llama-stack
# llama-stack-client # llama-stack-client
# openai # openai
huggingface-hub==0.33.0 huggingface-hub==0.34.1
# via llama-stack # via llama-stack
idna==3.10 idna==3.10
# via # via
@ -106,7 +106,7 @@ jsonschema-specifications==2024.10.1
# via jsonschema # via jsonschema
llama-api-client==0.1.2 llama-api-client==0.1.2
# via llama-stack # via llama-stack
llama-stack-client==0.2.15 llama-stack-client==0.2.16
# via llama-stack # via llama-stack
markdown-it-py==3.0.0 markdown-it-py==3.0.0
# via rich # via rich
@ -167,14 +167,14 @@ pyasn1==0.4.8
# rsa # rsa
pycparser==2.22 ; platform_python_implementation != 'PyPy' pycparser==2.22 ; platform_python_implementation != 'PyPy'
# via cffi # via cffi
pydantic==2.10.6 pydantic==2.11.7
# via # via
# fastapi # fastapi
# llama-api-client # llama-api-client
# llama-stack # llama-stack
# llama-stack-client # llama-stack-client
# openai # openai
pydantic-core==2.27.2 pydantic-core==2.33.2
# via pydantic # via pydantic
pygments==2.19.1 pygments==2.19.1
# via rich # via rich
@ -253,6 +253,9 @@ typing-extensions==4.12.2
# pydantic # pydantic
# pydantic-core # pydantic-core
# referencing # referencing
# typing-inspection
typing-inspection==0.4.1
# via pydantic
tzdata==2025.1 tzdata==2025.1
# via pandas # via pandas
urllib3==2.5.0 urllib3==2.5.0

View file

@ -222,9 +222,7 @@ cmd=( run -d "${PLATFORM_OPTS[@]}" --name llama-stack \
--network llama-net \ --network llama-net \
-p "${PORT}:${PORT}" \ -p "${PORT}:${PORT}" \
"${SERVER_IMAGE}" --port "${PORT}" \ "${SERVER_IMAGE}" --port "${PORT}" \
--env OLLAMA_INFERENCE_MODEL="${MODEL_ALIAS}" \ --env OLLAMA_URL="http://ollama-server:${OLLAMA_PORT}")
--env OLLAMA_URL="http://ollama-server:${OLLAMA_PORT}" \
--env ENABLE_OLLAMA=ollama)
log "🦙 Starting Llama Stack..." log "🦙 Starting Llama Stack..."
if ! execute_with_log $ENGINE "${cmd[@]}"; then if ! execute_with_log $ENGINE "${cmd[@]}"; then

View file

@ -8,6 +8,15 @@
PYTHON_VERSION=${PYTHON_VERSION:-3.12} PYTHON_VERSION=${PYTHON_VERSION:-3.12}
set -e
# Always run this at the end, even if something fails
cleanup() {
echo "Generating coverage report..."
uv run --python "$PYTHON_VERSION" coverage html -d htmlcov-$PYTHON_VERSION
}
trap cleanup EXIT
command -v uv >/dev/null 2>&1 || { echo >&2 "uv is required but it's not installed. Exiting."; exit 1; } command -v uv >/dev/null 2>&1 || { echo >&2 "uv is required but it's not installed. Exiting."; exit 1; }
uv python find "$PYTHON_VERSION" uv python find "$PYTHON_VERSION"
@ -19,6 +28,3 @@ fi
# Run unit tests with coverage # Run unit tests with coverage
uv run --python "$PYTHON_VERSION" --with-editable . --group unit \ uv run --python "$PYTHON_VERSION" --with-editable . --group unit \
coverage run --source=llama_stack -m pytest -s -v tests/unit/ "$@" coverage run --source=llama_stack -m pytest -s -v tests/unit/ "$@"
# Generate HTML coverage report
uv run --python "$PYTHON_VERSION" coverage html -d htmlcov-$PYTHON_VERSION

View file

@ -3,8 +3,7 @@ distribution_spec:
description: Custom distro for CI tests description: Custom distro for CI tests
providers: providers:
weather: weather:
- provider_id: kaze - provider_type: remote::kaze
provider_type: remote::kaze
image_type: venv image_type: venv
image_name: ci-test image_name: ci-test
external_providers_dir: ~/.llama/providers.d external_providers_dir: ~/.llama/providers.d

View file

@ -4,8 +4,7 @@ distribution_spec:
container_image: null container_image: null
providers: providers:
inference: inference:
- provider_id: ramalama - provider_type: remote::ramalama
provider_type: remote::ramalama
module: ramalama_stack==0.3.0a0 module: ramalama_stack==0.3.0a0
image_type: venv image_type: venv
image_name: ramalama-stack-test image_name: ramalama-stack-test

View file

@ -5,8 +5,14 @@
# the root directory of this source tree. # the root directory of this source tree.
import base64
import os
import tempfile
import pytest import pytest
from openai import OpenAI from openai import OpenAI
from reportlab.lib.pagesizes import letter
from reportlab.pdfgen import canvas
from llama_stack.distribution.library_client import LlamaStackAsLibraryClient from llama_stack.distribution.library_client import LlamaStackAsLibraryClient
@ -82,6 +88,14 @@ def skip_if_provider_isnt_vllm(client_with_models, model_id):
pytest.skip(f"Model {model_id} hosted by {provider.provider_type} doesn't support vllm extra_body parameters.") pytest.skip(f"Model {model_id} hosted by {provider.provider_type} doesn't support vllm extra_body parameters.")
def skip_if_provider_isnt_openai(client_with_models, model_id):
provider = provider_from_model(client_with_models, model_id)
if provider.provider_type != "remote::openai":
pytest.skip(
f"Model {model_id} hosted by {provider.provider_type} doesn't support chat completion calls with base64 encoded files."
)
@pytest.fixture @pytest.fixture
def openai_client(client_with_models): def openai_client(client_with_models):
base_url = f"{client_with_models.base_url}/v1/openai/v1" base_url = f"{client_with_models.base_url}/v1/openai/v1"
@ -418,3 +432,45 @@ def test_inference_store_tool_calls(compat_client, client_with_models, text_mode
# failed tool call parses show up as a message with content, so ensure # failed tool call parses show up as a message with content, so ensure
# that the retrieve response content matches the original request # that the retrieve response content matches the original request
assert retrieved_response.choices[0].message.content == content assert retrieved_response.choices[0].message.content == content
def test_openai_chat_completion_non_streaming_with_file(openai_client, client_with_models, text_model_id):
skip_if_provider_isnt_openai(client_with_models, text_model_id)
# Generate temporary PDF with "Hello World" text
with tempfile.NamedTemporaryFile(suffix=".pdf", delete=False) as temp_pdf:
c = canvas.Canvas(temp_pdf.name, pagesize=letter)
c.drawString(100, 750, "Hello World")
c.save()
# Read the PDF and encode it to base64
with open(temp_pdf.name, "rb") as pdf_file:
pdf_base64 = base64.b64encode(pdf_file.read()).decode("utf-8")
# Clean up temporary file
os.unlink(temp_pdf.name)
response = openai_client.chat.completions.create(
model=text_model_id,
messages=[
{
"role": "user",
"content": "Describe what you see in this PDF file.",
},
{
"role": "user",
"content": [
{
"type": "file",
"file": {
"filename": "my-temp-hello-world-pdf",
"file_data": f"data:application/pdf;base64,{pdf_base64}",
},
}
],
},
],
stream=False,
)
message_content = response.choices[0].message.content.lower().strip()
assert "hello world" in message_content

View file

@ -502,7 +502,7 @@ async def test_models_source_interaction_preserves_default(cached_disk_dist_regi
# Find the user model and provider model # Find the user model and provider model
user_model = next((m for m in models.data if m.identifier == "my-custom-alias"), None) user_model = next((m for m in models.data if m.identifier == "my-custom-alias"), None)
provider_model = next((m for m in models.data if m.identifier == "different-model"), None) provider_model = next((m for m in models.data if m.identifier == "test_provider/different-model"), None)
assert user_model is not None assert user_model is not None
assert user_model.source == RegistryEntrySource.via_register_api assert user_model.source == RegistryEntrySource.via_register_api
@ -558,12 +558,12 @@ async def test_models_source_interaction_cleanup_provider_models(cached_disk_dis
identifiers = {m.identifier for m in models.data} identifiers = {m.identifier for m in models.data}
assert "test_provider/user-model" in identifiers # User model preserved assert "test_provider/user-model" in identifiers # User model preserved
assert "provider-model-new" in identifiers # New provider model (uses provider's identifier) assert "test_provider/provider-model-new" in identifiers # New provider model (uses provider's identifier)
assert "provider-model-old" not in identifiers # Old provider model removed assert "test_provider/provider-model-old" not in identifiers # Old provider model removed
# Verify sources are correct # Verify sources are correct
user_model = next((m for m in models.data if m.identifier == "test_provider/user-model"), None) user_model = next((m for m in models.data if m.identifier == "test_provider/user-model"), None)
provider_model = next((m for m in models.data if m.identifier == "provider-model-new"), None) provider_model = next((m for m in models.data if m.identifier == "test_provider/provider-model-new"), None)
assert user_model.source == RegistryEntrySource.via_register_api assert user_model.source == RegistryEntrySource.via_register_api
assert provider_model.source == RegistryEntrySource.listed_from_provider assert provider_model.source == RegistryEntrySource.listed_from_provider

View file

@ -346,7 +346,7 @@ pip_packages:
def test_external_provider_from_module_building(self, mock_providers): def test_external_provider_from_module_building(self, mock_providers):
"""Test loading an external provider from a module during build (building=True, partial spec).""" """Test loading an external provider from a module during build (building=True, partial spec)."""
from llama_stack.distribution.datatypes import BuildConfig, DistributionSpec, Provider from llama_stack.distribution.datatypes import BuildConfig, BuildProvider, DistributionSpec
from llama_stack.providers.datatypes import Api from llama_stack.providers.datatypes import Api
# No importlib patch needed, should not import module when type of `config` is BuildConfig or DistributionSpec # No importlib patch needed, should not import module when type of `config` is BuildConfig or DistributionSpec
@ -358,10 +358,8 @@ pip_packages:
description="test", description="test",
providers={ providers={
"inference": [ "inference": [
Provider( BuildProvider(
provider_id="external_test",
provider_type="external_test", provider_type="external_test",
config={},
module="external_test", module="external_test",
) )
] ]

View file

@ -0,0 +1,125 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import os
from unittest.mock import AsyncMock, MagicMock, patch
from llama_stack.distribution.stack import replace_env_vars
from llama_stack.providers.remote.inference.openai.config import OpenAIConfig
from llama_stack.providers.remote.inference.openai.openai import OpenAIInferenceAdapter
class TestOpenAIBaseURLConfig:
"""Test that OPENAI_BASE_URL environment variable properly configures the OpenAI adapter."""
def test_default_base_url_without_env_var(self):
"""Test that the adapter uses the default OpenAI base URL when no environment variable is set."""
config = OpenAIConfig(api_key="test-key")
adapter = OpenAIInferenceAdapter(config)
assert adapter.get_base_url() == "https://api.openai.com/v1"
def test_custom_base_url_from_config(self):
"""Test that the adapter uses a custom base URL when provided in config."""
custom_url = "https://custom.openai.com/v1"
config = OpenAIConfig(api_key="test-key", base_url=custom_url)
adapter = OpenAIInferenceAdapter(config)
assert adapter.get_base_url() == custom_url
@patch.dict(os.environ, {"OPENAI_BASE_URL": "https://env.openai.com/v1"})
def test_base_url_from_environment_variable(self):
"""Test that the adapter uses base URL from OPENAI_BASE_URL environment variable."""
# Use sample_run_config which has proper environment variable syntax
config_data = OpenAIConfig.sample_run_config(api_key="test-key")
processed_config = replace_env_vars(config_data)
config = OpenAIConfig.model_validate(processed_config)
adapter = OpenAIInferenceAdapter(config)
assert adapter.get_base_url() == "https://env.openai.com/v1"
@patch.dict(os.environ, {"OPENAI_BASE_URL": "https://env.openai.com/v1"})
def test_config_overrides_environment_variable(self):
"""Test that explicit config value overrides environment variable."""
custom_url = "https://config.openai.com/v1"
config = OpenAIConfig(api_key="test-key", base_url=custom_url)
adapter = OpenAIInferenceAdapter(config)
# Config should take precedence over environment variable
assert adapter.get_base_url() == custom_url
@patch("llama_stack.providers.utils.inference.openai_mixin.AsyncOpenAI")
def test_client_uses_configured_base_url(self, mock_openai_class):
"""Test that the OpenAI client is initialized with the configured base URL."""
custom_url = "https://test.openai.com/v1"
config = OpenAIConfig(api_key="test-key", base_url=custom_url)
adapter = OpenAIInferenceAdapter(config)
# Mock the get_api_key method since it's delegated to LiteLLMOpenAIMixin
adapter.get_api_key = MagicMock(return_value="test-key")
# Access the client property to trigger AsyncOpenAI initialization
_ = adapter.client
# Verify AsyncOpenAI was called with the correct base_url
mock_openai_class.assert_called_once_with(
api_key="test-key",
base_url=custom_url,
)
@patch("llama_stack.providers.utils.inference.openai_mixin.AsyncOpenAI")
async def test_check_model_availability_uses_configured_url(self, mock_openai_class):
"""Test that check_model_availability uses the configured base URL."""
custom_url = "https://test.openai.com/v1"
config = OpenAIConfig(api_key="test-key", base_url=custom_url)
adapter = OpenAIInferenceAdapter(config)
# Mock the get_api_key method
adapter.get_api_key = MagicMock(return_value="test-key")
# Mock the AsyncOpenAI client and its models.retrieve method
mock_client = MagicMock()
mock_client.models.retrieve = AsyncMock(return_value=MagicMock())
mock_openai_class.return_value = mock_client
# Call check_model_availability and verify it returns True
assert await adapter.check_model_availability("gpt-4")
# Verify the client was created with the custom URL
mock_openai_class.assert_called_with(
api_key="test-key",
base_url=custom_url,
)
# Verify the method was called and returned True
mock_client.models.retrieve.assert_called_once_with("gpt-4")
@patch.dict(os.environ, {"OPENAI_BASE_URL": "https://proxy.openai.com/v1"})
@patch("llama_stack.providers.utils.inference.openai_mixin.AsyncOpenAI")
async def test_environment_variable_affects_model_availability_check(self, mock_openai_class):
"""Test that setting OPENAI_BASE_URL environment variable affects where model availability is checked."""
# Use sample_run_config which has proper environment variable syntax
config_data = OpenAIConfig.sample_run_config(api_key="test-key")
processed_config = replace_env_vars(config_data)
config = OpenAIConfig.model_validate(processed_config)
adapter = OpenAIInferenceAdapter(config)
# Mock the get_api_key method
adapter.get_api_key = MagicMock(return_value="test-key")
# Mock the AsyncOpenAI client
mock_client = MagicMock()
mock_client.models.retrieve = AsyncMock(return_value=MagicMock())
mock_openai_class.return_value = mock_client
# Call check_model_availability and verify it returns True
assert await adapter.check_model_availability("gpt-4")
# Verify the client was created with the environment variable URL
mock_openai_class.assert_called_with(
api_key="test-key",
base_url="https://proxy.openai.com/v1",
)

View file

@ -4,13 +4,19 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
import pytest
from pydantic import ValidationError
from llama_stack.apis.common.content_types import TextContentItem from llama_stack.apis.common.content_types import TextContentItem
from llama_stack.apis.inference import ( from llama_stack.apis.inference import (
CompletionMessage, CompletionMessage,
OpenAIAssistantMessageParam, OpenAIAssistantMessageParam,
OpenAIChatCompletionContentPartImageParam,
OpenAIChatCompletionContentPartTextParam, OpenAIChatCompletionContentPartTextParam,
OpenAIDeveloperMessageParam,
OpenAIImageURL,
OpenAISystemMessageParam, OpenAISystemMessageParam,
OpenAIToolMessageParam,
OpenAIUserMessageParam, OpenAIUserMessageParam,
SystemMessage, SystemMessage,
UserMessage, UserMessage,
@ -108,3 +114,71 @@ async def test_openai_messages_to_messages_with_content_list():
assert llama_messages[0].content[0].text == "system message" assert llama_messages[0].content[0].text == "system message"
assert llama_messages[1].content[0].text == "user message" assert llama_messages[1].content[0].text == "user message"
assert llama_messages[2].content[0].text == "assistant message" assert llama_messages[2].content[0].text == "assistant message"
@pytest.mark.parametrize(
"message_class,kwargs",
[
(OpenAISystemMessageParam, {}),
(OpenAIAssistantMessageParam, {}),
(OpenAIDeveloperMessageParam, {}),
(OpenAIUserMessageParam, {}),
(OpenAIToolMessageParam, {"tool_call_id": "call_123"}),
],
)
def test_message_accepts_text_string(message_class, kwargs):
"""Test that messages accept string text content."""
msg = message_class(content="Test message", **kwargs)
assert msg.content == "Test message"
@pytest.mark.parametrize(
"message_class,kwargs",
[
(OpenAISystemMessageParam, {}),
(OpenAIAssistantMessageParam, {}),
(OpenAIDeveloperMessageParam, {}),
(OpenAIUserMessageParam, {}),
(OpenAIToolMessageParam, {"tool_call_id": "call_123"}),
],
)
def test_message_accepts_text_list(message_class, kwargs):
"""Test that messages accept list of text content parts."""
content_list = [OpenAIChatCompletionContentPartTextParam(text="Test message")]
msg = message_class(content=content_list, **kwargs)
assert len(msg.content) == 1
assert msg.content[0].text == "Test message"
@pytest.mark.parametrize(
"message_class,kwargs",
[
(OpenAISystemMessageParam, {}),
(OpenAIAssistantMessageParam, {}),
(OpenAIDeveloperMessageParam, {}),
(OpenAIToolMessageParam, {"tool_call_id": "call_123"}),
],
)
def test_message_rejects_images(message_class, kwargs):
"""Test that system, assistant, developer, and tool messages reject image content."""
with pytest.raises(ValidationError):
message_class(
content=[
OpenAIChatCompletionContentPartImageParam(image_url=OpenAIImageURL(url="http://example.com/image.jpg"))
],
**kwargs,
)
def test_user_message_accepts_images():
"""Test that user messages accept image content (unlike other message types)."""
# List with images should work
msg = OpenAIUserMessageParam(
content=[
OpenAIChatCompletionContentPartTextParam(text="Describe this image:"),
OpenAIChatCompletionContentPartImageParam(image_url=OpenAIImageURL(url="http://example.com/image.jpg")),
]
)
assert len(msg.content) == 2
assert msg.content[0].text == "Describe this image:"
assert msg.content[1].image_url.url == "http://example.com/image.jpg"

View file

@ -162,26 +162,29 @@ async def test_register_model_existing_different(
await helper.register_model(known_model) await helper.register_model(known_model)
async def test_unregister_model(helper: ModelRegistryHelper, known_model: Model) -> None: # TODO: unregister_model functionality was removed/disabled by https://github.com/meta-llama/llama-stack/pull/2916
await helper.register_model(known_model) # duplicate entry # async def test_unregister_model(helper: ModelRegistryHelper, known_model: Model) -> None:
assert helper.get_provider_model_id(known_model.model_id) == known_model.provider_model_id # await helper.register_model(known_model) # duplicate entry
await helper.unregister_model(known_model.model_id) # assert helper.get_provider_model_id(known_model.model_id) == known_model.provider_model_id
assert helper.get_provider_model_id(known_model.model_id) is None # await helper.unregister_model(known_model.model_id)
# assert helper.get_provider_model_id(known_model.model_id) is None
async def test_unregister_unknown_model(helper: ModelRegistryHelper, unknown_model: Model) -> None: # TODO: unregister_model functionality was removed/disabled by https://github.com/meta-llama/llama-stack/pull/2916
with pytest.raises(ValueError): # async def test_unregister_unknown_model(helper: ModelRegistryHelper, unknown_model: Model) -> None:
await helper.unregister_model(unknown_model.model_id) # with pytest.raises(ValueError):
# await helper.unregister_model(unknown_model.model_id)
async def test_register_model_during_init(helper: ModelRegistryHelper, known_model: Model) -> None: async def test_register_model_during_init(helper: ModelRegistryHelper, known_model: Model) -> None:
assert helper.get_provider_model_id(known_model.provider_resource_id) == known_model.provider_model_id assert helper.get_provider_model_id(known_model.provider_resource_id) == known_model.provider_model_id
async def test_unregister_model_during_init(helper: ModelRegistryHelper, known_model: Model) -> None: # TODO: unregister_model functionality was removed/disabled by https://github.com/meta-llama/llama-stack/pull/2916
assert helper.get_provider_model_id(known_model.provider_resource_id) == known_model.provider_model_id # async def test_unregister_model_during_init(helper: ModelRegistryHelper, known_model: Model) -> None:
await helper.unregister_model(known_model.provider_resource_id) # assert helper.get_provider_model_id(known_model.provider_resource_id) == known_model.provider_model_id
assert helper.get_provider_model_id(known_model.provider_resource_id) is None # await helper.unregister_model(known_model.provider_resource_id)
# assert helper.get_provider_model_id(known_model.provider_resource_id) is None
async def test_register_model_from_check_model_availability( async def test_register_model_from_check_model_availability(

View file

@ -49,7 +49,7 @@ def github_token_app():
) )
# Add auth middleware # Add auth middleware
app.add_middleware(AuthenticationMiddleware, auth_config=auth_config) app.add_middleware(AuthenticationMiddleware, auth_config=auth_config, impls={})
@app.get("/test") @app.get("/test")
def test_endpoint(): def test_endpoint():
@ -149,7 +149,7 @@ def test_github_enterprise_support(mock_client_class):
access_policy=[], access_policy=[],
) )
app.add_middleware(AuthenticationMiddleware, auth_config=auth_config) app.add_middleware(AuthenticationMiddleware, auth_config=auth_config, impls={})
@app.get("/test") @app.get("/test")
def test_endpoint(): def test_endpoint():

Some files were not shown because too many files have changed in this diff.