Merge branch 'main' into fix/issue-3797-metadata-validation

Ashwin Bharambe 2025-11-19 10:09:30 -08:00 committed by GitHub
commit 0358770791
90 changed files with 6769 additions and 247 deletions


@@ -0,0 +1,35 @@
name: Setup TypeScript client
description: Conditionally checkout and link llama-stack-client-typescript based on client-version
inputs:
client-version:
description: 'Client version (latest or published)'
required: true
outputs:
ts-client-path:
description: 'Path or version to use for TypeScript client'
value: ${{ steps.set-path.outputs.ts-client-path }}
runs:
using: "composite"
steps:
- name: Checkout TypeScript client (latest)
if: ${{ inputs.client-version == 'latest' }}
uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
with:
repository: llamastack/llama-stack-client-typescript
ref: main
path: .ts-client-checkout
- name: Set TS_CLIENT_PATH
id: set-path
shell: bash
run: |
if [ "${{ inputs.client-version }}" = "latest" ]; then
echo "ts-client-path=${{ github.workspace }}/.ts-client-checkout" >> $GITHUB_OUTPUT
elif [ "${{ inputs.client-version }}" = "published" ]; then
echo "ts-client-path=^0.3.2" >> $GITHUB_OUTPUT
else
echo "::error::Invalid client-version: ${{ inputs.client-version }}"
exit 1
fi
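
For orientation, a minimal sketch of how the `ts-client-path` output is meant to be consumed downstream (mirroring the run script later in this commit; the variable wiring here is illustrative, not part of the action itself):

```bash
# Sketch only: resolve TS_CLIENT_PATH the way the integration-test script does.
# A directory means "use the local checkout built from main"; anything else is
# treated as an npm version range for the published llama-stack-client package.
TS_CLIENT_PATH="${TS_CLIENT_PATH:?expected from steps.setup-ts-client.outputs.ts-client-path}"

if [[ -d "$TS_CLIENT_PATH" ]]; then
  npm install "$TS_CLIENT_PATH" --silent                        # e.g. $GITHUB_WORKSPACE/.ts-client-checkout
else
  npm install "llama-stack-client@${TS_CLIENT_PATH}" --silent   # e.g. ^0.3.2
fi
```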


@@ -93,11 +93,27 @@ jobs:
suite: ${{ matrix.config.suite }}
inference-mode: 'replay'
- name: Setup Node.js for TypeScript client tests
if: ${{ matrix.client == 'server' }}
uses: actions/setup-node@39370e3970a6d050c480ffad4ff0ed4d3fdee5af # v4.1.0
with:
node-version: '20'
cache: 'npm'
cache-dependency-path: tests/integration/client-typescript/package-lock.json
- name: Setup TypeScript client
if: ${{ matrix.client == 'server' }}
id: setup-ts-client
uses: ./.github/actions/setup-typescript-client
with:
client-version: ${{ matrix.client-version }}
- name: Run tests
if: ${{ matrix.config.allowed_clients == null || contains(matrix.config.allowed_clients, matrix.client) }}
uses: ./.github/actions/run-and-record-tests
env:
OPENAI_API_KEY: dummy
TS_CLIENT_PATH: ${{ steps.setup-ts-client.outputs.ts-client-path || '' }}
with:
stack-config: >-
${{ matrix.config.stack_config


@@ -59,6 +59,30 @@ jobs:
ref: ${{ github.event.pull_request.head.sha }}
fetch-depth: 2
# Compute the Stainless branch name, prefixing with fork owner if PR is from a fork.
# For fork PRs like "contributor:fix/issue-123", this creates "preview/contributor/fix/issue-123"
# For same-repo PRs, this creates "preview/fix/issue-123"
- name: Compute branch names
id: branch-names
run: |
HEAD_REPO="${{ github.event.pull_request.head.repo.full_name }}"
BASE_REPO="${{ github.repository }}"
BRANCH_NAME="${{ github.event.pull_request.head.ref }}"
if [ "$HEAD_REPO" != "$BASE_REPO" ]; then
# Fork PR: prefix with fork owner for isolation
FORK_OWNER="${{ github.event.pull_request.head.repo.owner.login }}"
PREVIEW_BRANCH="preview/${FORK_OWNER}/${BRANCH_NAME}"
BASE_BRANCH="preview/base/${FORK_OWNER}/${BRANCH_NAME}"
else
# Same-repo PR
PREVIEW_BRANCH="preview/${BRANCH_NAME}"
BASE_BRANCH="preview/base/${BRANCH_NAME}"
fi
echo "preview_branch=${PREVIEW_BRANCH}" >> $GITHUB_OUTPUT
echo "base_branch=${BASE_BRANCH}" >> $GITHUB_OUTPUT
# This action builds preview SDKs from the OpenAPI spec changes and
# posts/updates a comment on the PR with build results and links to the preview.
- name: Run preview builds
@@ -73,6 +97,8 @@
base_sha: ${{ github.event.pull_request.base.sha }}
base_ref: ${{ github.event.pull_request.base.ref }}
head_sha: ${{ github.event.pull_request.head.sha }}
branch: ${{ steps.branch-names.outputs.preview_branch }}
base_branch: ${{ steps.branch-names.outputs.base_branch }}
merge:
if: github.event.action == 'closed' && github.event.pull_request.merged == true
@@ -90,12 +116,33 @@
ref: ${{ github.event.pull_request.head.sha }}
fetch-depth: 2
# Compute the Stainless branch name, prefixing with fork owner if PR is from a fork.
# For fork PRs like "contributor:fix/issue-123", this creates "preview/contributor/fix/issue-123"
# For same-repo PRs, this creates "preview/fix/issue-123"
- name: Compute branch names
id: branch-names
run: |
HEAD_REPO="${{ github.event.pull_request.head.repo.full_name }}"
BASE_REPO="${{ github.repository }}"
BRANCH_NAME="${{ github.event.pull_request.head.ref }}"
if [ "$HEAD_REPO" != "$BASE_REPO" ]; then
# Fork PR: prefix with fork owner for isolation
FORK_OWNER="${{ github.event.pull_request.head.repo.owner.login }}"
MERGE_BRANCH="preview/${FORK_OWNER}/${BRANCH_NAME}"
else
# Same-repo PR
MERGE_BRANCH="preview/${BRANCH_NAME}"
fi
echo "merge_branch=${MERGE_BRANCH}" >> $GITHUB_OUTPUT
# Note that this only merges in changes that happened on the last build on
-# preview/${{ github.head_ref }}. It's possible that there are OAS/config
-# changes that haven't been built, if the preview-sdk job didn't finish
+# the computed preview branch. It's possible that there are OAS/config
+# changes that haven't been built, if the preview job didn't finish
# before this step starts. In theory we want to wait for all builds
-# against preview/${{ github.head_ref }} to complete, but assuming that
-# the preview-sdk job happens before the PR merge, it should be fine.
+# against the preview branch to complete, but assuming that
+# the preview job happens before the PR merge, it should be fine.
- name: Run merge build
uses: stainless-api/upload-openapi-spec-action/merge@32823b096b4319c53ee948d702d9052873af485f # 1.6.0
with:
@@ -108,3 +155,4 @@ jobs:
base_sha: ${{ github.event.pull_request.base.sha }}
base_ref: ${{ github.event.pull_request.base.ref }}
head_sha: ${{ github.event.pull_request.head.sha }}
merge_branch: ${{ steps.branch-names.outputs.merge_branch }}

.gitignore

@@ -35,3 +35,5 @@ docs/static/imported-files/
docs/docs/api-deprecated/
docs/docs/api-experimental/
docs/docs/api/
tests/integration/client-typescript/node_modules/
.ts-client-checkout/


@@ -24,7 +24,7 @@ https://learn.microsoft.com/en-us/azure/ai-foundry/openai/overview
| `allowed_models` | `list[str] \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
| `refresh_models` | `bool` | No | False | Whether to refresh models periodically from the provider |
| `api_key` | `SecretStr \| None` | No | | Authentication credential for the provider |
-| `api_base` | `HttpUrl` | No | | Azure API base for Azure (e.g., https://your-resource-name.openai.azure.com) |
+| `base_url` | `HttpUrl \| None` | No | | Azure API base for Azure (e.g., https://your-resource-name.openai.azure.com/openai/v1) |
| `api_version` | `str \| None` | No | | Azure API version for Azure (e.g., 2024-12-01-preview) |
| `api_type` | `str \| None` | No | azure | Azure API type for Azure (e.g., azure) |
@@ -32,7 +32,7 @@ https://learn.microsoft.com/en-us/azure/ai-foundry/openai/overview
```yaml
api_key: ${env.AZURE_API_KEY:=}
-api_base: ${env.AZURE_API_BASE:=}
+base_url: ${env.AZURE_API_BASE:=}
api_version: ${env.AZURE_API_VERSION:=}
api_type: ${env.AZURE_API_TYPE:=}
```


@@ -17,11 +17,11 @@ Cerebras inference provider for running models on Cerebras Cloud platform.
| `allowed_models` | `list[str] \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
| `refresh_models` | `bool` | No | False | Whether to refresh models periodically from the provider |
| `api_key` | `SecretStr \| None` | No | | Authentication credential for the provider |
-| `base_url` | `str` | No | https://api.cerebras.ai | Base URL for the Cerebras API |
+| `base_url` | `HttpUrl \| None` | No | https://api.cerebras.ai/v1 | Base URL for the Cerebras API |
## Sample Configuration
```yaml
-base_url: https://api.cerebras.ai
+base_url: https://api.cerebras.ai/v1
api_key: ${env.CEREBRAS_API_KEY:=}
```


@@ -17,11 +17,11 @@ Databricks inference provider for running models on Databricks' unified analytic
| `allowed_models` | `list[str] \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
| `refresh_models` | `bool` | No | False | Whether to refresh models periodically from the provider |
| `api_token` | `SecretStr \| None` | No | | The Databricks API token |
-| `url` | `str \| None` | No | | The URL for the Databricks model serving endpoint |
+| `base_url` | `HttpUrl \| None` | No | | The URL for the Databricks model serving endpoint (should include /serving-endpoints path) |
## Sample Configuration
```yaml
-url: ${env.DATABRICKS_HOST:=}
+base_url: ${env.DATABRICKS_HOST:=}
api_token: ${env.DATABRICKS_TOKEN:=}
```


@@ -17,11 +17,11 @@ Fireworks AI inference provider for Llama models and other AI models on the Fire
| `allowed_models` | `list[str] \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
| `refresh_models` | `bool` | No | False | Whether to refresh models periodically from the provider |
| `api_key` | `SecretStr \| None` | No | | Authentication credential for the provider |
-| `url` | `str` | No | https://api.fireworks.ai/inference/v1 | The URL for the Fireworks server |
+| `base_url` | `HttpUrl \| None` | No | https://api.fireworks.ai/inference/v1 | The URL for the Fireworks server |
## Sample Configuration
```yaml
-url: https://api.fireworks.ai/inference/v1
+base_url: https://api.fireworks.ai/inference/v1
api_key: ${env.FIREWORKS_API_KEY:=}
```


@@ -17,11 +17,11 @@ Groq inference provider for ultra-fast inference using Groq's LPU technology.
| `allowed_models` | `list[str] \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
| `refresh_models` | `bool` | No | False | Whether to refresh models periodically from the provider |
| `api_key` | `SecretStr \| None` | No | | Authentication credential for the provider |
-| `url` | `str` | No | https://api.groq.com | The URL for the Groq AI server |
+| `base_url` | `HttpUrl \| None` | No | https://api.groq.com/openai/v1 | The URL for the Groq AI server |
## Sample Configuration
```yaml
-url: https://api.groq.com
+base_url: https://api.groq.com/openai/v1
api_key: ${env.GROQ_API_KEY:=}
```


@@ -17,11 +17,11 @@ Llama OpenAI-compatible provider for using Llama models with OpenAI API format.
| `allowed_models` | `list[str] \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
| `refresh_models` | `bool` | No | False | Whether to refresh models periodically from the provider |
| `api_key` | `SecretStr \| None` | No | | Authentication credential for the provider |
-| `openai_compat_api_base` | `str` | No | https://api.llama.com/compat/v1/ | The URL for the Llama API server |
+| `base_url` | `HttpUrl \| None` | No | https://api.llama.com/compat/v1/ | The URL for the Llama API server |
## Sample Configuration
```yaml
-openai_compat_api_base: https://api.llama.com/compat/v1/
+base_url: https://api.llama.com/compat/v1/
api_key: ${env.LLAMA_API_KEY}
```


@@ -17,15 +17,13 @@ NVIDIA inference provider for accessing NVIDIA NIM models and AI services.
| `allowed_models` | `list[str] \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
| `refresh_models` | `bool` | No | False | Whether to refresh models periodically from the provider |
| `api_key` | `SecretStr \| None` | No | | Authentication credential for the provider |
-| `url` | `str` | No | https://integrate.api.nvidia.com | A base url for accessing the NVIDIA NIM |
+| `base_url` | `HttpUrl \| None` | No | https://integrate.api.nvidia.com/v1 | A base url for accessing the NVIDIA NIM |
| `timeout` | `int` | No | 60 | Timeout for the HTTP requests |
-| `append_api_version` | `bool` | No | True | When set to false, the API version will not be appended to the base_url. By default, it is true. |
| `rerank_model_to_url` | `dict[str, str]` | No | `{'nv-rerank-qa-mistral-4b:1': 'https://ai.api.nvidia.com/v1/retrieval/nvidia/reranking', 'nvidia/nv-rerankqa-mistral-4b-v3': 'https://ai.api.nvidia.com/v1/retrieval/nvidia/nv-rerankqa-mistral-4b-v3/reranking', 'nvidia/llama-3.2-nv-rerankqa-1b-v2': 'https://ai.api.nvidia.com/v1/retrieval/nvidia/llama-3_2-nv-rerankqa-1b-v2/reranking'}` | Mapping of rerank model identifiers to their API endpoints. |
## Sample Configuration
```yaml
-url: ${env.NVIDIA_BASE_URL:=https://integrate.api.nvidia.com}
+base_url: ${env.NVIDIA_BASE_URL:=https://integrate.api.nvidia.com/v1}
api_key: ${env.NVIDIA_API_KEY:=}
-append_api_version: ${env.NVIDIA_APPEND_API_VERSION:=True}
```


@@ -16,10 +16,10 @@ Ollama inference provider for running local models through the Ollama runtime.
|-------|------|----------|---------|-------------|
| `allowed_models` | `list[str] \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
| `refresh_models` | `bool` | No | False | Whether to refresh models periodically from the provider |
-| `url` | `str` | No | http://localhost:11434 | |
+| `base_url` | `HttpUrl \| None` | No | http://localhost:11434/v1 | |
## Sample Configuration
```yaml
-url: ${env.OLLAMA_URL:=http://localhost:11434}
+base_url: ${env.OLLAMA_URL:=http://localhost:11434/v1}
```


@@ -17,7 +17,7 @@ OpenAI inference provider for accessing GPT models and other OpenAI services.
| `allowed_models` | `list[str] \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
| `refresh_models` | `bool` | No | False | Whether to refresh models periodically from the provider |
| `api_key` | `SecretStr \| None` | No | | Authentication credential for the provider |
-| `base_url` | `str` | No | https://api.openai.com/v1 | Base URL for OpenAI API |
+| `base_url` | `HttpUrl \| None` | No | https://api.openai.com/v1 | Base URL for OpenAI API |
## Sample Configuration


@@ -17,11 +17,11 @@ Passthrough inference provider for connecting to any external inference service
| `allowed_models` | `list[str] \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
| `refresh_models` | `bool` | No | False | Whether to refresh models periodically from the provider |
| `api_key` | `SecretStr \| None` | No | | Authentication credential for the provider |
-| `url` | `str` | No | | The URL for the passthrough endpoint |
+| `base_url` | `HttpUrl \| None` | No | | The URL for the passthrough endpoint |
## Sample Configuration
```yaml
-url: ${env.PASSTHROUGH_URL}
+base_url: ${env.PASSTHROUGH_URL}
api_key: ${env.PASSTHROUGH_API_KEY}
```


@@ -17,11 +17,11 @@ RunPod inference provider for running models on RunPod's cloud GPU platform.
| `allowed_models` | `list[str] \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
| `refresh_models` | `bool` | No | False | Whether to refresh models periodically from the provider |
| `api_token` | `SecretStr \| None` | No | | The API token |
-| `url` | `str \| None` | No | | The URL for the Runpod model serving endpoint |
+| `base_url` | `HttpUrl \| None` | No | | The URL for the Runpod model serving endpoint |
## Sample Configuration
```yaml
-url: ${env.RUNPOD_URL:=}
+base_url: ${env.RUNPOD_URL:=}
api_token: ${env.RUNPOD_API_TOKEN}
```


@@ -17,11 +17,11 @@ SambaNova inference provider for running models on SambaNova's dataflow architec
| `allowed_models` | `list[str] \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
| `refresh_models` | `bool` | No | False | Whether to refresh models periodically from the provider |
| `api_key` | `SecretStr \| None` | No | | Authentication credential for the provider |
-| `url` | `str` | No | https://api.sambanova.ai/v1 | The URL for the SambaNova AI server |
+| `base_url` | `HttpUrl \| None` | No | https://api.sambanova.ai/v1 | The URL for the SambaNova AI server |
## Sample Configuration
```yaml
-url: https://api.sambanova.ai/v1
+base_url: https://api.sambanova.ai/v1
api_key: ${env.SAMBANOVA_API_KEY:=}
```


@@ -16,10 +16,10 @@ Text Generation Inference (TGI) provider for HuggingFace model serving.
|-------|------|----------|---------|-------------|
| `allowed_models` | `list[str] \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
| `refresh_models` | `bool` | No | False | Whether to refresh models periodically from the provider |
-| `url` | `str` | No | | The URL for the TGI serving endpoint |
+| `base_url` | `HttpUrl \| None` | No | | The URL for the TGI serving endpoint (should include /v1 path) |
## Sample Configuration
```yaml
-url: ${env.TGI_URL:=}
+base_url: ${env.TGI_URL:=}
```


@@ -17,11 +17,11 @@ Together AI inference provider for open-source models and collaborative AI devel
| `allowed_models` | `list[str] \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
| `refresh_models` | `bool` | No | False | Whether to refresh models periodically from the provider |
| `api_key` | `SecretStr \| None` | No | | Authentication credential for the provider |
-| `url` | `str` | No | https://api.together.xyz/v1 | The URL for the Together AI server |
+| `base_url` | `HttpUrl \| None` | No | https://api.together.xyz/v1 | The URL for the Together AI server |
## Sample Configuration
```yaml
-url: https://api.together.xyz/v1
+base_url: https://api.together.xyz/v1
api_key: ${env.TOGETHER_API_KEY:=}
```


@@ -17,14 +17,14 @@ Remote vLLM inference provider for connecting to vLLM servers.
| `allowed_models` | `list[str] \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
| `refresh_models` | `bool` | No | False | Whether to refresh models periodically from the provider |
| `api_token` | `SecretStr \| None` | No | | The API token |
-| `url` | `str \| None` | No | | The URL for the vLLM model serving endpoint |
+| `base_url` | `HttpUrl \| None` | No | | The URL for the vLLM model serving endpoint |
| `max_tokens` | `int` | No | 4096 | Maximum number of tokens to generate. |
| `tls_verify` | `bool \| str` | No | True | Whether to verify TLS certificates. Can be a boolean or a path to a CA certificate file. |
## Sample Configuration
```yaml
-url: ${env.VLLM_URL:=}
+base_url: ${env.VLLM_URL:=}
max_tokens: ${env.VLLM_MAX_TOKENS:=4096}
api_token: ${env.VLLM_API_TOKEN:=fake}
tls_verify: ${env.VLLM_TLS_VERIFY:=true}


@@ -17,14 +17,14 @@ IBM WatsonX inference provider for accessing AI models on IBM's WatsonX platform
| `allowed_models` | `list[str] \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
| `refresh_models` | `bool` | No | False | Whether to refresh models periodically from the provider |
| `api_key` | `SecretStr \| None` | No | | Authentication credential for the provider |
-| `url` | `str` | No | https://us-south.ml.cloud.ibm.com | A base url for accessing the watsonx.ai |
+| `base_url` | `HttpUrl \| None` | No | https://us-south.ml.cloud.ibm.com | A base url for accessing the watsonx.ai |
| `project_id` | `str \| None` | No | | The watsonx.ai project ID |
| `timeout` | `int` | No | 60 | Timeout for the HTTP requests |
## Sample Configuration
```yaml
-url: ${env.WATSONX_BASE_URL:=https://us-south.ml.cloud.ibm.com}
+base_url: ${env.WATSONX_BASE_URL:=https://us-south.ml.cloud.ibm.com}
api_key: ${env.WATSONX_API_KEY:=}
project_id: ${env.WATSONX_PROJECT_ID:=}
```


@@ -287,9 +287,9 @@ start_container() {
# On macOS/Windows, use host.docker.internal to reach host from container
# On Linux with --network host, use localhost
if [[ "$(uname)" == "Darwin" ]] || [[ "$(uname)" == *"MINGW"* ]]; then
-OLLAMA_URL="${OLLAMA_URL:-http://host.docker.internal:11434}"
+OLLAMA_URL="${OLLAMA_URL:-http://host.docker.internal:11434/v1}"
else
-OLLAMA_URL="${OLLAMA_URL:-http://localhost:11434}"
+OLLAMA_URL="${OLLAMA_URL:-http://localhost:11434/v1}"
fi
DOCKER_ENV_VARS="$DOCKER_ENV_VARS -e OLLAMA_URL=$OLLAMA_URL"


@@ -16,16 +16,16 @@ import sys
from tests.integration.suites import SETUP_DEFINITIONS, SUITE_DEFINITIONS
-def get_setup_env_vars(setup_name, suite_name=None):
+def get_setup_config(setup_name, suite_name=None):
"""
-Get environment variables for a setup, with optional suite default fallback.
+Get full configuration (env vars + defaults) for a setup.
Args:
setup_name: Name of the setup (e.g., 'ollama', 'gpt')
suite_name: Optional suite name to get default setup if setup_name is None
Returns:
-Dictionary of environment variables
+Dictionary with 'env' and 'defaults' keys
"""
# If no setup specified, try to get default from suite
if not setup_name and suite_name:
@@ -34,7 +34,7 @@ def get_setup_env_vars(setup_name, suite_name=None):
setup_name = suite.default_setup
if not setup_name:
-return {}
+return {"env": {}, "defaults": {}}
setup = SETUP_DEFINITIONS.get(setup_name)
if not setup:
@@ -44,27 +44,31 @@ def get_setup_env_vars(setup_name, suite_name=None):
)
sys.exit(1)
-return setup.env
+return {"env": setup.env, "defaults": setup.defaults}
def main():
-parser = argparse.ArgumentParser(description="Extract environment variables from a test setup")
+parser = argparse.ArgumentParser(description="Extract environment variables and defaults from a test setup")
parser.add_argument("--setup", help="Setup name (e.g., ollama, gpt)")
parser.add_argument("--suite", help="Suite name to get default setup from if --setup not provided")
parser.add_argument("--format", choices=["bash", "json"], default="bash", help="Output format (default: bash)")
args = parser.parse_args()
-env_vars = get_setup_env_vars(args.setup, args.suite)
+config = get_setup_config(args.setup, args.suite)
if args.format == "bash":
-# Output as bash export statements
-for key, value in env_vars.items():
+# Output env vars as bash export statements
+for key, value in config["env"].items():
print(f"export {key}='{value}'")
# Output defaults as bash export statements with LLAMA_STACK_TEST_ prefix
for key, value in config["defaults"].items():
env_key = f"LLAMA_STACK_TEST_{key.upper()}"
print(f"export {env_key}='{value}'")
elif args.format == "json":
import json
-print(json.dumps(env_vars))
+print(json.dumps(config))
if __name__ == "__main__":
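
For illustration, the bash output now carries two groups of exports; a hedged sketch of what a run might print, assuming a setup that defines one env var and one default (the script path, setup name, and values below are made up for the example):

```bash
# Hypothetical invocation and output -- names and values are illustrative only.
# $ python scripts/get_setup_env.py --setup ollama --format bash
export OLLAMA_URL='http://localhost:11434/v1'       # from setup.env, exported as-is
export LLAMA_STACK_TEST_TEXT_MODEL='some-model-id'  # from setup.defaults, key upper-cased and prefixed
```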


@@ -640,7 +640,7 @@ cmd=( run -d "${PLATFORM_OPTS[@]}" --name llama-stack \
--network llama-net \
-p "${PORT}:${PORT}" \
"${server_env_opts[@]}" \
- -e OLLAMA_URL="http://ollama-server:${OLLAMA_PORT}" \
+ -e OLLAMA_URL="http://ollama-server:${OLLAMA_PORT}/v1" \
"${SERVER_IMAGE}" --port "${PORT}")
log "🦙 Starting Llama Stack..."


@@ -181,6 +181,10 @@ echo "$SETUP_ENV"
eval "$SETUP_ENV"
echo ""
# Export suite and setup names for TypeScript tests
export LLAMA_STACK_TEST_SUITE="$TEST_SUITE"
export LLAMA_STACK_TEST_SETUP="$TEST_SETUP"
ROOT_DIR="$THIS_DIR/.."
cd $ROOT_DIR
@@ -212,6 +216,71 @@ find_available_port() {
return 1
}
run_client_ts_tests() {
if ! command -v npm &>/dev/null; then
echo "npm could not be found; ensure Node.js is installed"
return 1
fi
pushd tests/integration/client-typescript >/dev/null
# Determine if TS_CLIENT_PATH is a directory path or an npm version
if [[ -d "$TS_CLIENT_PATH" ]]; then
# It's a directory path - use local checkout
if [[ ! -f "$TS_CLIENT_PATH/package.json" ]]; then
echo "Error: $TS_CLIENT_PATH exists but doesn't look like llama-stack-client-typescript (no package.json)"
popd >/dev/null
return 1
fi
echo "Using local llama-stack-client-typescript from: $TS_CLIENT_PATH"
# Build the TypeScript client first
echo "Building TypeScript client..."
pushd "$TS_CLIENT_PATH" >/dev/null
npm install --silent
npm run build --silent
popd >/dev/null
# Install other dependencies first
if [[ "${CI:-}" == "true" || "${CI:-}" == "1" ]]; then
npm ci --silent
else
npm install --silent
fi
# Then install the client from local directory
echo "Installing llama-stack-client from: $TS_CLIENT_PATH"
npm install "$TS_CLIENT_PATH" --silent
else
# It's an npm version specifier - install from npm
echo "Installing llama-stack-client@${TS_CLIENT_PATH} from npm"
if [[ "${CI:-}" == "true" || "${CI:-}" == "1" ]]; then
npm ci --silent
npm install "llama-stack-client@${TS_CLIENT_PATH}" --silent
else
npm install "llama-stack-client@${TS_CLIENT_PATH}" --silent
fi
fi
# Verify installation
echo "Verifying llama-stack-client installation..."
if npm list llama-stack-client 2>/dev/null | grep -q llama-stack-client; then
echo "✅ llama-stack-client successfully installed"
npm list llama-stack-client
else
echo "❌ llama-stack-client not found in node_modules"
echo "Installed packages:"
npm list --depth=0
popd >/dev/null
return 1
fi
echo "Running TypeScript tests for suite $TEST_SUITE (setup $TEST_SETUP)"
npm test
popd >/dev/null
}
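
For context, the two TS_CLIENT_PATH modes handled above could be exercised like this (the script name and flags are assumptions for illustration; only the TS_CLIENT_PATH semantics come from run_client_ts_tests):

```bash
# Illustrative invocations -- the script path and flags are assumed, not taken from this commit.

# 1) Local checkout: a directory containing llama-stack-client-typescript (built and npm-installed from it).
TS_CLIENT_PATH="$GITHUB_WORKSPACE/.ts-client-checkout" ./scripts/integration-tests.sh --stack-config server:ci-tests

# 2) Published client: any non-directory value is passed to npm as a version range.
TS_CLIENT_PATH="^0.3.2" ./scripts/integration-tests.sh --stack-config server:ci-tests
```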
# Start Llama Stack Server if needed
if [[ "$STACK_CONFIG" == *"server:"* && "$COLLECT_ONLY" == false ]]; then
# Find an available port for the server
@@ -221,6 +290,7 @@ if [[ "$STACK_CONFIG" == *"server:"* && "$COLLECT_ONLY" == false ]]; then
exit 1
fi
export LLAMA_STACK_PORT
export TEST_API_BASE_URL="http://localhost:$LLAMA_STACK_PORT"
echo "Will use port: $LLAMA_STACK_PORT"
stop_server() {
@@ -298,6 +368,7 @@ if [[ "$STACK_CONFIG" == *"docker:"* && "$COLLECT_ONLY" == false ]]; then
exit 1
fi
export LLAMA_STACK_PORT
export TEST_API_BASE_URL="http://localhost:$LLAMA_STACK_PORT"
echo "Will use port: $LLAMA_STACK_PORT"
echo "=== Building Docker Image for distribution: $DISTRO ==="
@@ -506,5 +577,10 @@ else
exit 1
fi
# Run TypeScript client tests if TS_CLIENT_PATH is set
if [[ $exit_code -eq 0 && -n "${TS_CLIENT_PATH:-}" && "${LLAMA_STACK_TEST_STACK_CONFIG_TYPE:-}" == "server" ]]; then
run_client_ts_tests
fi
echo "" echo ""
echo "=== Integration Tests Complete ===" echo "=== Integration Tests Complete ==="


@@ -17,32 +17,32 @@ providers:
- provider_id: ${env.CEREBRAS_API_KEY:+cerebras}
provider_type: remote::cerebras
config:
-base_url: https://api.cerebras.ai
+base_url: https://api.cerebras.ai/v1
api_key: ${env.CEREBRAS_API_KEY:=}
- provider_id: ${env.OLLAMA_URL:+ollama}
provider_type: remote::ollama
config:
-url: ${env.OLLAMA_URL:=http://localhost:11434}
+base_url: ${env.OLLAMA_URL:=http://localhost:11434/v1}
- provider_id: ${env.VLLM_URL:+vllm}
provider_type: remote::vllm
config:
-url: ${env.VLLM_URL:=}
+base_url: ${env.VLLM_URL:=}
max_tokens: ${env.VLLM_MAX_TOKENS:=4096}
api_token: ${env.VLLM_API_TOKEN:=fake}
tls_verify: ${env.VLLM_TLS_VERIFY:=true}
- provider_id: ${env.TGI_URL:+tgi}
provider_type: remote::tgi
config:
-url: ${env.TGI_URL:=}
+base_url: ${env.TGI_URL:=}
- provider_id: fireworks
provider_type: remote::fireworks
config:
-url: https://api.fireworks.ai/inference/v1
+base_url: https://api.fireworks.ai/inference/v1
api_key: ${env.FIREWORKS_API_KEY:=}
- provider_id: together
provider_type: remote::together
config:
-url: https://api.together.xyz/v1
+base_url: https://api.together.xyz/v1
api_key: ${env.TOGETHER_API_KEY:=}
- provider_id: bedrock
provider_type: remote::bedrock
@@ -52,9 +52,8 @@ providers:
- provider_id: ${env.NVIDIA_API_KEY:+nvidia}
provider_type: remote::nvidia
config:
-url: ${env.NVIDIA_BASE_URL:=https://integrate.api.nvidia.com}
+base_url: ${env.NVIDIA_BASE_URL:=https://integrate.api.nvidia.com/v1}
api_key: ${env.NVIDIA_API_KEY:=}
-append_api_version: ${env.NVIDIA_APPEND_API_VERSION:=True}
- provider_id: openai
provider_type: remote::openai
config:
@@ -76,18 +75,18 @@ providers:
- provider_id: groq
provider_type: remote::groq
config:
-url: https://api.groq.com
+base_url: https://api.groq.com/openai/v1
api_key: ${env.GROQ_API_KEY:=}
- provider_id: sambanova
provider_type: remote::sambanova
config:
-url: https://api.sambanova.ai/v1
+base_url: https://api.sambanova.ai/v1
api_key: ${env.SAMBANOVA_API_KEY:=}
- provider_id: ${env.AZURE_API_KEY:+azure}
provider_type: remote::azure
config:
api_key: ${env.AZURE_API_KEY:=}
-api_base: ${env.AZURE_API_BASE:=}
+base_url: ${env.AZURE_API_BASE:=}
api_version: ${env.AZURE_API_VERSION:=}
api_type: ${env.AZURE_API_TYPE:=}
- provider_id: sentence-transformers


@@ -16,9 +16,8 @@ providers:
- provider_id: nvidia
provider_type: remote::nvidia
config:
-url: ${env.NVIDIA_BASE_URL:=https://integrate.api.nvidia.com}
+base_url: ${env.NVIDIA_BASE_URL:=https://integrate.api.nvidia.com/v1}
api_key: ${env.NVIDIA_API_KEY:=}
-append_api_version: ${env.NVIDIA_APPEND_API_VERSION:=True}
- provider_id: nvidia
provider_type: remote::nvidia
config:


@@ -16,9 +16,8 @@ providers:
- provider_id: nvidia
provider_type: remote::nvidia
config:
-url: ${env.NVIDIA_BASE_URL:=https://integrate.api.nvidia.com}
+base_url: ${env.NVIDIA_BASE_URL:=https://integrate.api.nvidia.com/v1}
api_key: ${env.NVIDIA_API_KEY:=}
-append_api_version: ${env.NVIDIA_APPEND_API_VERSION:=True}
vector_io:
- provider_id: faiss
provider_type: inline::faiss


@@ -27,12 +27,12 @@ providers:
- provider_id: groq
provider_type: remote::groq
config:
-url: https://api.groq.com
+base_url: https://api.groq.com/openai/v1
api_key: ${env.GROQ_API_KEY:=}
- provider_id: together
provider_type: remote::together
config:
-url: https://api.together.xyz/v1
+base_url: https://api.together.xyz/v1
api_key: ${env.TOGETHER_API_KEY:=}
vector_io:
- provider_id: sqlite-vec


@@ -11,7 +11,7 @@ providers:
- provider_id: vllm-inference
provider_type: remote::vllm
config:
-url: ${env.VLLM_URL:=http://localhost:8000/v1}
+base_url: ${env.VLLM_URL:=}
max_tokens: ${env.VLLM_MAX_TOKENS:=4096}
api_token: ${env.VLLM_API_TOKEN:=fake}
tls_verify: ${env.VLLM_TLS_VERIFY:=true}

View file

@ -15,7 +15,7 @@ providers:
- provider_id: watsonx - provider_id: watsonx
provider_type: remote::watsonx provider_type: remote::watsonx
config: config:
url: ${env.WATSONX_BASE_URL:=https://us-south.ml.cloud.ibm.com} base_url: ${env.WATSONX_BASE_URL:=https://us-south.ml.cloud.ibm.com}
api_key: ${env.WATSONX_API_KEY:=} api_key: ${env.WATSONX_API_KEY:=}
project_id: ${env.WATSONX_PROJECT_ID:=} project_id: ${env.WATSONX_PROJECT_ID:=}
vector_io: vector_io:

View file

@ -23,7 +23,7 @@ async def get_provider_impl(
config, config,
deps[Api.inference], deps[Api.inference],
deps[Api.vector_io], deps[Api.vector_io],
deps[Api.safety], deps.get(Api.safety),
deps[Api.tool_runtime], deps[Api.tool_runtime],
deps[Api.tool_groups], deps[Api.tool_groups],
deps[Api.conversations], deps[Api.conversations],

View file

@ -41,7 +41,7 @@ class MetaReferenceAgentsImpl(Agents):
config: MetaReferenceAgentsImplConfig, config: MetaReferenceAgentsImplConfig,
inference_api: Inference, inference_api: Inference,
vector_io_api: VectorIO, vector_io_api: VectorIO,
safety_api: Safety, safety_api: Safety | None,
tool_runtime_api: ToolRuntime, tool_runtime_api: ToolRuntime,
tool_groups_api: ToolGroups, tool_groups_api: ToolGroups,
conversations_api: Conversations, conversations_api: Conversations,

View file

@ -67,7 +67,7 @@ class OpenAIResponsesImpl:
tool_runtime_api: ToolRuntime, tool_runtime_api: ToolRuntime,
responses_store: ResponsesStore, responses_store: ResponsesStore,
vector_io_api: VectorIO, # VectorIO vector_io_api: VectorIO, # VectorIO
safety_api: Safety, safety_api: Safety | None,
conversations_api: Conversations, conversations_api: Conversations,
): ):
self.inference_api = inference_api self.inference_api = inference_api
@ -273,6 +273,14 @@ class OpenAIResponsesImpl:
guardrail_ids = extract_guardrail_ids(guardrails) if guardrails else [] guardrail_ids = extract_guardrail_ids(guardrails) if guardrails else []
# Validate that Safety API is available if guardrails are requested
if guardrail_ids and self.safety_api is None:
raise ValueError(
"Cannot process guardrails: Safety API is not configured.\n\n"
"To use guardrails, ensure the Safety API is configured in your stack, or remove "
"the 'guardrails' parameter from your request."
)
if conversation is not None: if conversation is not None:
if previous_response_id is not None: if previous_response_id is not None:
raise ValueError( raise ValueError(

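The hunks above make the Safety API an optional dependency of the meta-reference agents provider: it is fetched with `deps.get(Api.safety)` instead of `deps[Api.safety]`, the constructors accept `Safety | None`, and a request that asks for guardrails without a configured Safety provider fails fast with a `ValueError`. A rough sketch of that pattern, using hypothetical names in place of the real classes:

```python
# Hypothetical, simplified illustration of the optional-dependency pattern above.
class AgentsImpl:
    def __init__(self, config, safety_api=None):
        self.config = config
        self.safety_api = safety_api  # may be None when no Safety provider is configured

    def validate_guardrails(self, guardrail_ids: list[str]) -> None:
        # Mirrors the check added to OpenAIResponsesImpl: guardrails require Safety.
        if guardrail_ids and self.safety_api is None:
            raise ValueError("Cannot process guardrails: Safety API is not configured.")


def create_impl(deps: dict, config) -> AgentsImpl:
    # deps.get(...) rather than deps[...]: Safety is optional, so a missing key is fine.
    return AgentsImpl(config=config, safety_api=deps.get("safety"))
```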
View file

@ -66,6 +66,7 @@ from llama_stack_api import (
OpenAIResponseUsage, OpenAIResponseUsage,
OpenAIResponseUsageInputTokensDetails, OpenAIResponseUsageInputTokensDetails,
OpenAIResponseUsageOutputTokensDetails, OpenAIResponseUsageOutputTokensDetails,
Safety,
WebSearchToolTypes, WebSearchToolTypes,
) )
@ -111,7 +112,7 @@ class StreamingResponseOrchestrator:
max_infer_iters: int, max_infer_iters: int,
tool_executor, # Will be the tool execution logic from the main class tool_executor, # Will be the tool execution logic from the main class
instructions: str | None, instructions: str | None,
safety_api, safety_api: Safety | None,
guardrail_ids: list[str] | None = None, guardrail_ids: list[str] | None = None,
prompt: OpenAIResponsePrompt | None = None, prompt: OpenAIResponsePrompt | None = None,
parallel_tool_calls: bool | None = None, parallel_tool_calls: bool | None = None,

View file

@ -320,11 +320,15 @@ def is_function_tool_call(
return False return False
async def run_guardrails(safety_api: Safety, messages: str, guardrail_ids: list[str]) -> str | None: async def run_guardrails(safety_api: Safety | None, messages: str, guardrail_ids: list[str]) -> str | None:
"""Run guardrails against messages and return violation message if blocked.""" """Run guardrails against messages and return violation message if blocked."""
if not messages: if not messages:
return None return None
# If safety API is not available, skip guardrails
if safety_api is None:
return None
# Look up shields to get their provider_resource_id (actual model ID) # Look up shields to get their provider_resource_id (actual model ID)
model_ids = [] model_ids = []
# TODO: list_shields not in Safety interface but available at runtime via API routing # TODO: list_shields not in Safety interface but available at runtime via API routing

View file

@ -30,12 +30,14 @@ def available_providers() -> list[ProviderSpec]:
config_class="llama_stack.providers.inline.agents.meta_reference.MetaReferenceAgentsImplConfig", config_class="llama_stack.providers.inline.agents.meta_reference.MetaReferenceAgentsImplConfig",
api_dependencies=[ api_dependencies=[
Api.inference, Api.inference,
Api.safety,
Api.vector_io, Api.vector_io,
Api.tool_runtime, Api.tool_runtime,
Api.tool_groups, Api.tool_groups,
Api.conversations, Api.conversations,
], ],
optional_api_dependencies=[
Api.safety,
],
description="Meta's reference implementation of an agent system that can use tools, access vector databases, and perform complex reasoning tasks.", description="Meta's reference implementation of an agent system that can use tools, access vector databases, and perform complex reasoning tasks.",
), ),
] ]

View file

@ -4,8 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
from urllib.parse import urljoin
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
from .config import AzureConfig from .config import AzureConfig
@ -22,4 +20,4 @@ class AzureInferenceAdapter(OpenAIMixin):
Returns the Azure API base URL from the configuration. Returns the Azure API base URL from the configuration.
""" """
return urljoin(str(self.config.api_base), "/openai/v1") return str(self.config.base_url)

View file

@ -32,8 +32,9 @@ class AzureProviderDataValidator(BaseModel):
@json_schema_type @json_schema_type
class AzureConfig(RemoteInferenceProviderConfig): class AzureConfig(RemoteInferenceProviderConfig):
api_base: HttpUrl = Field( base_url: HttpUrl | None = Field(
description="Azure API base for Azure (e.g., https://your-resource-name.openai.azure.com)", default=None,
description="Azure API base for Azure (e.g., https://your-resource-name.openai.azure.com/openai/v1)",
) )
api_version: str | None = Field( api_version: str | None = Field(
default_factory=lambda: os.getenv("AZURE_API_VERSION"), default_factory=lambda: os.getenv("AZURE_API_VERSION"),
@ -48,14 +49,14 @@ class AzureConfig(RemoteInferenceProviderConfig):
def sample_run_config( def sample_run_config(
cls, cls,
api_key: str = "${env.AZURE_API_KEY:=}", api_key: str = "${env.AZURE_API_KEY:=}",
api_base: str = "${env.AZURE_API_BASE:=}", base_url: str = "${env.AZURE_API_BASE:=}",
api_version: str = "${env.AZURE_API_VERSION:=}", api_version: str = "${env.AZURE_API_VERSION:=}",
api_type: str = "${env.AZURE_API_TYPE:=}", api_type: str = "${env.AZURE_API_TYPE:=}",
**kwargs, **kwargs,
) -> dict[str, Any]: ) -> dict[str, Any]:
return { return {
"api_key": api_key, "api_key": api_key,
"api_base": api_base, "base_url": base_url,
"api_version": api_version, "api_version": api_version,
"api_type": api_type, "api_type": api_type,
} }

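The Azure change above removes the `urljoin` call because the configured `base_url` is now expected to carry the full `/openai/v1` path itself. For a typical resource URL the two forms are equivalent, as this small check shows (the resource name is a placeholder):

```python
from urllib.parse import urljoin

old_api_base = "https://your-resource-name.openai.azure.com"            # old api_base value
new_base_url = "https://your-resource-name.openai.azure.com/openai/v1"  # new base_url value

# The adapter used to compute the endpoint; now the config supplies it directly.
assert urljoin(old_api_base, "/openai/v1") == new_base_url
```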
View file

@ -4,8 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
from urllib.parse import urljoin
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
from llama_stack_api import ( from llama_stack_api import (
OpenAIEmbeddingsRequestWithExtraBody, OpenAIEmbeddingsRequestWithExtraBody,
@ -21,7 +19,7 @@ class CerebrasInferenceAdapter(OpenAIMixin):
provider_data_api_key_field: str = "cerebras_api_key" provider_data_api_key_field: str = "cerebras_api_key"
def get_base_url(self) -> str: def get_base_url(self) -> str:
return urljoin(self.config.base_url, "v1") return str(self.config.base_url)
async def openai_embeddings( async def openai_embeddings(
self, self,

View file

@ -7,12 +7,12 @@
import os import os
from typing import Any from typing import Any
from pydantic import BaseModel, Field from pydantic import BaseModel, Field, HttpUrl
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
from llama_stack_api import json_schema_type from llama_stack_api import json_schema_type
DEFAULT_BASE_URL = "https://api.cerebras.ai" DEFAULT_BASE_URL = "https://api.cerebras.ai/v1"
class CerebrasProviderDataValidator(BaseModel): class CerebrasProviderDataValidator(BaseModel):
@ -24,8 +24,8 @@ class CerebrasProviderDataValidator(BaseModel):
@json_schema_type @json_schema_type
class CerebrasImplConfig(RemoteInferenceProviderConfig): class CerebrasImplConfig(RemoteInferenceProviderConfig):
base_url: str = Field( base_url: HttpUrl | None = Field(
default=os.environ.get("CEREBRAS_BASE_URL", DEFAULT_BASE_URL), default=HttpUrl(os.environ.get("CEREBRAS_BASE_URL", DEFAULT_BASE_URL)),
description="Base URL for the Cerebras API", description="Base URL for the Cerebras API",
) )

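A recurring detail in these config diffs: URL fields move from plain `str` to pydantic's `HttpUrl`, which is why the adapters now return `str(self.config.base_url)` from `get_base_url()`: the field holds a validated URL object rather than a string. A small sketch of that round trip, modeled on (but not identical to) the Cerebras config above:

```python
from pydantic import BaseModel, Field, HttpUrl

class CerebrasLikeConfig(BaseModel):
    # Mirrors the new-style field: a validated URL whose default already includes /v1.
    base_url: HttpUrl | None = Field(
        default=HttpUrl("https://api.cerebras.ai/v1"),
        description="Base URL for the Cerebras API",
    )

cfg = CerebrasLikeConfig()
print(type(cfg.base_url).__name__)  # a pydantic URL type, not str
print(str(cfg.base_url))            # e.g. "https://api.cerebras.ai/v1"
# Passing a malformed URL raises pydantic.ValidationError at construction time.
```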
View file

@ -6,7 +6,7 @@
from typing import Any from typing import Any
from pydantic import BaseModel, Field, SecretStr from pydantic import BaseModel, Field, HttpUrl, SecretStr
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
from llama_stack_api import json_schema_type from llama_stack_api import json_schema_type
@ -21,9 +21,9 @@ class DatabricksProviderDataValidator(BaseModel):
@json_schema_type @json_schema_type
class DatabricksImplConfig(RemoteInferenceProviderConfig): class DatabricksImplConfig(RemoteInferenceProviderConfig):
url: str | None = Field( base_url: HttpUrl | None = Field(
default=None, default=None,
description="The URL for the Databricks model serving endpoint", description="The URL for the Databricks model serving endpoint (should include /serving-endpoints path)",
) )
auth_credential: SecretStr | None = Field( auth_credential: SecretStr | None = Field(
default=None, default=None,
@ -34,11 +34,11 @@ class DatabricksImplConfig(RemoteInferenceProviderConfig):
@classmethod @classmethod
def sample_run_config( def sample_run_config(
cls, cls,
url: str = "${env.DATABRICKS_HOST:=}", base_url: str = "${env.DATABRICKS_HOST:=}",
api_token: str = "${env.DATABRICKS_TOKEN:=}", api_token: str = "${env.DATABRICKS_TOKEN:=}",
**kwargs: Any, **kwargs: Any,
) -> dict[str, Any]: ) -> dict[str, Any]:
return { return {
"url": url, "base_url": base_url,
"api_token": api_token, "api_token": api_token,
} }

View file

@ -29,15 +29,21 @@ class DatabricksInferenceAdapter(OpenAIMixin):
} }
def get_base_url(self) -> str: def get_base_url(self) -> str:
return f"{self.config.url}/serving-endpoints" return str(self.config.base_url)
async def list_provider_model_ids(self) -> Iterable[str]: async def list_provider_model_ids(self) -> Iterable[str]:
# Filter out None values from endpoint names # Filter out None values from endpoint names
api_token = self._get_api_key_from_config_or_provider_data() api_token = self._get_api_key_from_config_or_provider_data()
# WorkspaceClient expects base host without /serving-endpoints suffix
base_url_str = str(self.config.base_url)
if base_url_str.endswith("/serving-endpoints"):
host = base_url_str[:-18] # Remove '/serving-endpoints'
else:
host = base_url_str
return [ return [
endpoint.name # type: ignore[misc] endpoint.name # type: ignore[misc]
for endpoint in WorkspaceClient( for endpoint in WorkspaceClient(
host=self.config.url, token=api_token host=host, token=api_token
).serving_endpoints.list() # TODO: this is not async ).serving_endpoints.list() # TODO: this is not async
] ]

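The Databricks adapter above trims the `/serving-endpoints` suffix with a hard-coded slice length (`[:-18]`). The same trimming can be written with `str.removesuffix` (Python 3.9+); this is an alternative sketch, not the code the diff ships:

```python
def databricks_host(base_url: str) -> str:
    """Return the workspace host expected by WorkspaceClient.

    The configured base_url includes /serving-endpoints for OpenAI-style calls,
    but the Databricks SDK wants only the workspace host.
    """
    return base_url.rstrip("/").removesuffix("/serving-endpoints")

# Example with a placeholder workspace URL:
assert databricks_host("https://dbc-123.cloud.databricks.com/serving-endpoints") == (
    "https://dbc-123.cloud.databricks.com"
)
```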
View file

@ -6,7 +6,7 @@
from typing import Any from typing import Any
from pydantic import Field from pydantic import Field, HttpUrl
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
from llama_stack_api import json_schema_type from llama_stack_api import json_schema_type
@ -14,14 +14,14 @@ from llama_stack_api import json_schema_type
@json_schema_type @json_schema_type
class FireworksImplConfig(RemoteInferenceProviderConfig): class FireworksImplConfig(RemoteInferenceProviderConfig):
url: str = Field( base_url: HttpUrl | None = Field(
default="https://api.fireworks.ai/inference/v1", default=HttpUrl("https://api.fireworks.ai/inference/v1"),
description="The URL for the Fireworks server", description="The URL for the Fireworks server",
) )
@classmethod @classmethod
def sample_run_config(cls, api_key: str = "${env.FIREWORKS_API_KEY:=}", **kwargs) -> dict[str, Any]: def sample_run_config(cls, api_key: str = "${env.FIREWORKS_API_KEY:=}", **kwargs) -> dict[str, Any]:
return { return {
"url": "https://api.fireworks.ai/inference/v1", "base_url": "https://api.fireworks.ai/inference/v1",
"api_key": api_key, "api_key": api_key,
} }

View file

@ -24,4 +24,4 @@ class FireworksInferenceAdapter(OpenAIMixin):
provider_data_api_key_field: str = "fireworks_api_key" provider_data_api_key_field: str = "fireworks_api_key"
def get_base_url(self) -> str: def get_base_url(self) -> str:
return "https://api.fireworks.ai/inference/v1" return str(self.config.base_url)

View file

@ -6,7 +6,7 @@
from typing import Any from typing import Any
from pydantic import BaseModel, Field from pydantic import BaseModel, Field, HttpUrl
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
from llama_stack_api import json_schema_type from llama_stack_api import json_schema_type
@ -21,14 +21,14 @@ class GroqProviderDataValidator(BaseModel):
@json_schema_type @json_schema_type
class GroqConfig(RemoteInferenceProviderConfig): class GroqConfig(RemoteInferenceProviderConfig):
url: str = Field( base_url: HttpUrl | None = Field(
default="https://api.groq.com", default=HttpUrl("https://api.groq.com/openai/v1"),
description="The URL for the Groq AI server", description="The URL for the Groq AI server",
) )
@classmethod @classmethod
def sample_run_config(cls, api_key: str = "${env.GROQ_API_KEY:=}", **kwargs) -> dict[str, Any]: def sample_run_config(cls, api_key: str = "${env.GROQ_API_KEY:=}", **kwargs) -> dict[str, Any]:
return { return {
"url": "https://api.groq.com", "base_url": "https://api.groq.com/openai/v1",
"api_key": api_key, "api_key": api_key,
} }

View file

@ -15,4 +15,4 @@ class GroqInferenceAdapter(OpenAIMixin):
provider_data_api_key_field: str = "groq_api_key" provider_data_api_key_field: str = "groq_api_key"
def get_base_url(self) -> str: def get_base_url(self) -> str:
return f"{self.config.url}/openai/v1" return str(self.config.base_url)

View file

@ -6,7 +6,7 @@
from typing import Any from typing import Any
from pydantic import BaseModel, Field from pydantic import BaseModel, Field, HttpUrl
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
from llama_stack_api import json_schema_type from llama_stack_api import json_schema_type
@ -21,14 +21,14 @@ class LlamaProviderDataValidator(BaseModel):
@json_schema_type @json_schema_type
class LlamaCompatConfig(RemoteInferenceProviderConfig): class LlamaCompatConfig(RemoteInferenceProviderConfig):
openai_compat_api_base: str = Field( base_url: HttpUrl | None = Field(
default="https://api.llama.com/compat/v1/", default=HttpUrl("https://api.llama.com/compat/v1/"),
description="The URL for the Llama API server", description="The URL for the Llama API server",
) )
@classmethod @classmethod
def sample_run_config(cls, api_key: str = "${env.LLAMA_API_KEY}", **kwargs) -> dict[str, Any]: def sample_run_config(cls, api_key: str = "${env.LLAMA_API_KEY}", **kwargs) -> dict[str, Any]:
return { return {
"openai_compat_api_base": "https://api.llama.com/compat/v1/", "base_url": "https://api.llama.com/compat/v1/",
"api_key": api_key, "api_key": api_key,
} }

View file

@ -31,7 +31,7 @@ class LlamaCompatInferenceAdapter(OpenAIMixin):
:return: The Llama API base URL :return: The Llama API base URL
""" """
return self.config.openai_compat_api_base return str(self.config.base_url)
async def openai_completion( async def openai_completion(
self, self,

View file

@ -7,7 +7,7 @@
import os import os
from typing import Any from typing import Any
from pydantic import BaseModel, Field from pydantic import BaseModel, Field, HttpUrl
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
from llama_stack_api import json_schema_type from llama_stack_api import json_schema_type
@ -44,18 +44,14 @@ class NVIDIAConfig(RemoteInferenceProviderConfig):
URL of your running NVIDIA NIM and do not need to set the api_key. URL of your running NVIDIA NIM and do not need to set the api_key.
""" """
url: str = Field( base_url: HttpUrl | None = Field(
default_factory=lambda: os.getenv("NVIDIA_BASE_URL", "https://integrate.api.nvidia.com"), default_factory=lambda: os.getenv("NVIDIA_BASE_URL", "https://integrate.api.nvidia.com/v1"),
description="A base url for accessing the NVIDIA NIM", description="A base url for accessing the NVIDIA NIM",
) )
timeout: int = Field( timeout: int = Field(
default=60, default=60,
description="Timeout for the HTTP requests", description="Timeout for the HTTP requests",
) )
append_api_version: bool = Field(
default_factory=lambda: os.getenv("NVIDIA_APPEND_API_VERSION", "True").lower() != "false",
description="When set to false, the API version will not be appended to the base_url. By default, it is true.",
)
rerank_model_to_url: dict[str, str] = Field( rerank_model_to_url: dict[str, str] = Field(
default_factory=lambda: { default_factory=lambda: {
"nv-rerank-qa-mistral-4b:1": "https://ai.api.nvidia.com/v1/retrieval/nvidia/reranking", "nv-rerank-qa-mistral-4b:1": "https://ai.api.nvidia.com/v1/retrieval/nvidia/reranking",
@ -68,13 +64,11 @@ class NVIDIAConfig(RemoteInferenceProviderConfig):
@classmethod @classmethod
def sample_run_config( def sample_run_config(
cls, cls,
url: str = "${env.NVIDIA_BASE_URL:=https://integrate.api.nvidia.com}", base_url: HttpUrl | None = "${env.NVIDIA_BASE_URL:=https://integrate.api.nvidia.com/v1}",
api_key: str = "${env.NVIDIA_API_KEY:=}", api_key: str = "${env.NVIDIA_API_KEY:=}",
append_api_version: bool = "${env.NVIDIA_APPEND_API_VERSION:=True}",
**kwargs, **kwargs,
) -> dict[str, Any]: ) -> dict[str, Any]:
return { return {
"url": url, "base_url": base_url,
"api_key": api_key, "api_key": api_key,
"append_api_version": append_api_version,
} }

View file

@ -44,7 +44,7 @@ class NVIDIAInferenceAdapter(OpenAIMixin):
} }
async def initialize(self) -> None: async def initialize(self) -> None:
logger.info(f"Initializing NVIDIAInferenceAdapter({self.config.url})...") logger.info(f"Initializing NVIDIAInferenceAdapter({self.config.base_url})...")
if _is_nvidia_hosted(self.config): if _is_nvidia_hosted(self.config):
if not self.config.auth_credential: if not self.config.auth_credential:
@ -72,7 +72,7 @@ class NVIDIAInferenceAdapter(OpenAIMixin):
:return: The NVIDIA API base URL :return: The NVIDIA API base URL
""" """
return f"{self.config.url}/v1" if self.config.append_api_version else self.config.url return str(self.config.base_url)
async def list_provider_model_ids(self) -> Iterable[str]: async def list_provider_model_ids(self) -> Iterable[str]:
""" """

View file

@ -8,4 +8,4 @@ from . import NVIDIAConfig
def _is_nvidia_hosted(config: NVIDIAConfig) -> bool: def _is_nvidia_hosted(config: NVIDIAConfig) -> bool:
return "integrate.api.nvidia.com" in config.url return "integrate.api.nvidia.com" in str(config.base_url)

View file

@ -6,20 +6,22 @@
from typing import Any from typing import Any
from pydantic import Field, SecretStr from pydantic import Field, HttpUrl, SecretStr
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
DEFAULT_OLLAMA_URL = "http://localhost:11434" DEFAULT_OLLAMA_URL = "http://localhost:11434/v1"
class OllamaImplConfig(RemoteInferenceProviderConfig): class OllamaImplConfig(RemoteInferenceProviderConfig):
auth_credential: SecretStr | None = Field(default=None, exclude=True) auth_credential: SecretStr | None = Field(default=None, exclude=True)
url: str = DEFAULT_OLLAMA_URL base_url: HttpUrl | None = Field(default=HttpUrl(DEFAULT_OLLAMA_URL))
@classmethod @classmethod
def sample_run_config(cls, url: str = "${env.OLLAMA_URL:=http://localhost:11434}", **kwargs) -> dict[str, Any]: def sample_run_config(
cls, base_url: str = "${env.OLLAMA_URL:=http://localhost:11434/v1}", **kwargs
) -> dict[str, Any]:
return { return {
"url": url, "base_url": base_url,
} }

View file

@ -55,17 +55,23 @@ class OllamaInferenceAdapter(OpenAIMixin):
# ollama client attaches itself to the current event loop (sadly?) # ollama client attaches itself to the current event loop (sadly?)
loop = asyncio.get_running_loop() loop = asyncio.get_running_loop()
if loop not in self._clients: if loop not in self._clients:
self._clients[loop] = AsyncOllamaClient(host=self.config.url) # Ollama client expects base URL without /v1 suffix
base_url_str = str(self.config.base_url)
if base_url_str.endswith("/v1"):
host = base_url_str[:-3]
else:
host = base_url_str
self._clients[loop] = AsyncOllamaClient(host=host)
return self._clients[loop] return self._clients[loop]
def get_api_key(self): def get_api_key(self):
return "NO KEY REQUIRED" return "NO KEY REQUIRED"
def get_base_url(self): def get_base_url(self):
return self.config.url.rstrip("/") + "/v1" return str(self.config.base_url)
async def initialize(self) -> None: async def initialize(self) -> None:
logger.info(f"checking connectivity to Ollama at `{self.config.url}`...") logger.info(f"checking connectivity to Ollama at `{self.config.base_url}`...")
r = await self.health() r = await self.health()
if r["status"] == HealthStatus.ERROR: if r["status"] == HealthStatus.ERROR:
logger.warning( logger.warning(

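With the Ollama default now ending in `/v1`, one configured URL serves two consumers: the OpenAI-compatible path uses `base_url` verbatim, while the native Ollama client still wants the bare host, so the adapter strips the suffix before constructing it. A short sketch of that split, assuming `AsyncOllamaClient` is the `ollama` package's `AsyncClient` and that `OpenAIMixin` ultimately hands `get_base_url()` to an OpenAI-compatible client:

```python
from ollama import AsyncClient as AsyncOllamaClient
from openai import AsyncOpenAI

BASE_URL = "http://localhost:11434/v1"  # the new default, /v1 included

# OpenAI-compatible path: the configured URL is used as-is.
openai_client = AsyncOpenAI(base_url=BASE_URL, api_key="NO KEY REQUIRED")

# Native Ollama client: expects the bare host, so the /v1 suffix is trimmed.
native_client = AsyncOllamaClient(host=BASE_URL.removesuffix("/v1"))
```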
View file

@ -6,7 +6,7 @@
from typing import Any from typing import Any
from pydantic import BaseModel, Field from pydantic import BaseModel, Field, HttpUrl
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
from llama_stack_api import json_schema_type from llama_stack_api import json_schema_type
@ -21,8 +21,8 @@ class OpenAIProviderDataValidator(BaseModel):
@json_schema_type @json_schema_type
class OpenAIConfig(RemoteInferenceProviderConfig): class OpenAIConfig(RemoteInferenceProviderConfig):
base_url: str = Field( base_url: HttpUrl | None = Field(
default="https://api.openai.com/v1", default=HttpUrl("https://api.openai.com/v1"),
description="Base URL for OpenAI API", description="Base URL for OpenAI API",
) )

View file

@ -35,4 +35,4 @@ class OpenAIInferenceAdapter(OpenAIMixin):
Returns the OpenAI API base URL from the configuration. Returns the OpenAI API base URL from the configuration.
""" """
return self.config.base_url return str(self.config.base_url)

View file

@ -6,7 +6,7 @@
from typing import Any from typing import Any
from pydantic import Field from pydantic import Field, HttpUrl
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
from llama_stack_api import json_schema_type from llama_stack_api import json_schema_type
@ -14,16 +14,16 @@ from llama_stack_api import json_schema_type
@json_schema_type @json_schema_type
class PassthroughImplConfig(RemoteInferenceProviderConfig): class PassthroughImplConfig(RemoteInferenceProviderConfig):
url: str = Field( base_url: HttpUrl | None = Field(
default=None, default=None,
description="The URL for the passthrough endpoint", description="The URL for the passthrough endpoint",
) )
@classmethod @classmethod
def sample_run_config( def sample_run_config(
cls, url: str = "${env.PASSTHROUGH_URL}", api_key: str = "${env.PASSTHROUGH_API_KEY}", **kwargs cls, base_url: HttpUrl | None = "${env.PASSTHROUGH_URL}", api_key: str = "${env.PASSTHROUGH_API_KEY}", **kwargs
) -> dict[str, Any]: ) -> dict[str, Any]:
return { return {
"url": url, "base_url": base_url,
"api_key": api_key, "api_key": api_key,
} }

View file

@ -82,8 +82,8 @@ class PassthroughInferenceAdapter(NeedsRequestProviderData, Inference):
def _get_passthrough_url(self) -> str: def _get_passthrough_url(self) -> str:
"""Get the passthrough URL from config or provider data.""" """Get the passthrough URL from config or provider data."""
if self.config.url is not None: if self.config.base_url is not None:
return self.config.url return str(self.config.base_url)
provider_data = self.get_request_provider_data() provider_data = self.get_request_provider_data()
if provider_data is None: if provider_data is None:

View file

@ -6,7 +6,7 @@
from typing import Any from typing import Any
from pydantic import BaseModel, Field, SecretStr from pydantic import BaseModel, Field, HttpUrl, SecretStr
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
from llama_stack_api import json_schema_type from llama_stack_api import json_schema_type
@ -21,7 +21,7 @@ class RunpodProviderDataValidator(BaseModel):
@json_schema_type @json_schema_type
class RunpodImplConfig(RemoteInferenceProviderConfig): class RunpodImplConfig(RemoteInferenceProviderConfig):
url: str | None = Field( base_url: HttpUrl | None = Field(
default=None, default=None,
description="The URL for the Runpod model serving endpoint", description="The URL for the Runpod model serving endpoint",
) )
@ -34,6 +34,6 @@ class RunpodImplConfig(RemoteInferenceProviderConfig):
@classmethod @classmethod
def sample_run_config(cls, **kwargs: Any) -> dict[str, Any]: def sample_run_config(cls, **kwargs: Any) -> dict[str, Any]:
return { return {
"url": "${env.RUNPOD_URL:=}", "base_url": "${env.RUNPOD_URL:=}",
"api_token": "${env.RUNPOD_API_TOKEN}", "api_token": "${env.RUNPOD_API_TOKEN}",
} }

View file

@ -28,7 +28,7 @@ class RunpodInferenceAdapter(OpenAIMixin):
def get_base_url(self) -> str: def get_base_url(self) -> str:
"""Get base URL for OpenAI client.""" """Get base URL for OpenAI client."""
return self.config.url return str(self.config.base_url)
async def openai_chat_completion( async def openai_chat_completion(
self, self,

View file

@ -6,7 +6,7 @@
from typing import Any from typing import Any
from pydantic import BaseModel, Field from pydantic import BaseModel, Field, HttpUrl
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
from llama_stack_api import json_schema_type from llama_stack_api import json_schema_type
@ -21,14 +21,14 @@ class SambaNovaProviderDataValidator(BaseModel):
@json_schema_type @json_schema_type
class SambaNovaImplConfig(RemoteInferenceProviderConfig): class SambaNovaImplConfig(RemoteInferenceProviderConfig):
url: str = Field( base_url: HttpUrl | None = Field(
default="https://api.sambanova.ai/v1", default=HttpUrl("https://api.sambanova.ai/v1"),
description="The URL for the SambaNova AI server", description="The URL for the SambaNova AI server",
) )
@classmethod @classmethod
def sample_run_config(cls, api_key: str = "${env.SAMBANOVA_API_KEY:=}", **kwargs) -> dict[str, Any]: def sample_run_config(cls, api_key: str = "${env.SAMBANOVA_API_KEY:=}", **kwargs) -> dict[str, Any]:
return { return {
"url": "https://api.sambanova.ai/v1", "base_url": "https://api.sambanova.ai/v1",
"api_key": api_key, "api_key": api_key,
} }

View file

@ -25,4 +25,4 @@ class SambaNovaInferenceAdapter(OpenAIMixin):
:return: The SambaNova base URL :return: The SambaNova base URL
""" """
return self.config.url return str(self.config.base_url)

View file

@ -5,7 +5,7 @@
# the root directory of this source tree. # the root directory of this source tree.
from pydantic import BaseModel, Field, SecretStr from pydantic import BaseModel, Field, HttpUrl, SecretStr
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
from llama_stack_api import json_schema_type from llama_stack_api import json_schema_type
@ -15,18 +15,19 @@ from llama_stack_api import json_schema_type
class TGIImplConfig(RemoteInferenceProviderConfig): class TGIImplConfig(RemoteInferenceProviderConfig):
auth_credential: SecretStr | None = Field(default=None, exclude=True) auth_credential: SecretStr | None = Field(default=None, exclude=True)
url: str = Field( base_url: HttpUrl | None = Field(
description="The URL for the TGI serving endpoint", default=None,
description="The URL for the TGI serving endpoint (should include /v1 path)",
) )
@classmethod @classmethod
def sample_run_config( def sample_run_config(
cls, cls,
url: str = "${env.TGI_URL:=}", base_url: str = "${env.TGI_URL:=}",
**kwargs, **kwargs,
): ):
return { return {
"url": url, "base_url": base_url,
} }

View file

@ -8,7 +8,7 @@
from collections.abc import Iterable from collections.abc import Iterable
from huggingface_hub import AsyncInferenceClient, HfApi from huggingface_hub import AsyncInferenceClient, HfApi
from pydantic import SecretStr from pydantic import HttpUrl, SecretStr
from llama_stack.log import get_logger from llama_stack.log import get_logger
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
@ -23,7 +23,7 @@ log = get_logger(name=__name__, category="inference::tgi")
class _HfAdapter(OpenAIMixin): class _HfAdapter(OpenAIMixin):
url: str base_url: HttpUrl
api_key: SecretStr api_key: SecretStr
hf_client: AsyncInferenceClient hf_client: AsyncInferenceClient
@ -36,7 +36,7 @@ class _HfAdapter(OpenAIMixin):
return "NO KEY REQUIRED" return "NO KEY REQUIRED"
def get_base_url(self): def get_base_url(self):
return self.url return self.base_url
async def list_provider_model_ids(self) -> Iterable[str]: async def list_provider_model_ids(self) -> Iterable[str]:
return [self.model_id] return [self.model_id]
@ -50,14 +50,20 @@ class _HfAdapter(OpenAIMixin):
class TGIAdapter(_HfAdapter): class TGIAdapter(_HfAdapter):
async def initialize(self, config: TGIImplConfig) -> None: async def initialize(self, config: TGIImplConfig) -> None:
if not config.url: if not config.base_url:
raise ValueError("You must provide a URL in run.yaml (or via the TGI_URL environment variable) to use TGI.") raise ValueError("You must provide a URL in run.yaml (or via the TGI_URL environment variable) to use TGI.")
log.info(f"Initializing TGI client with url={config.url}") log.info(f"Initializing TGI client with url={config.base_url}")
self.hf_client = AsyncInferenceClient(model=config.url, provider="hf-inference") # Extract base URL without /v1 for HF client initialization
base_url_str = str(config.base_url).rstrip("/")
if base_url_str.endswith("/v1"):
base_url_for_client = base_url_str[:-3]
else:
base_url_for_client = base_url_str
self.hf_client = AsyncInferenceClient(model=base_url_for_client, provider="hf-inference")
endpoint_info = await self.hf_client.get_endpoint_info() endpoint_info = await self.hf_client.get_endpoint_info()
self.max_tokens = endpoint_info["max_total_tokens"] self.max_tokens = endpoint_info["max_total_tokens"]
self.model_id = endpoint_info["model_id"] self.model_id = endpoint_info["model_id"]
self.url = f"{config.url.rstrip('/')}/v1" self.base_url = config.base_url
self.api_key = SecretStr("NO_KEY") self.api_key = SecretStr("NO_KEY")

View file

@ -6,7 +6,7 @@
from typing import Any from typing import Any
from pydantic import Field from pydantic import Field, HttpUrl
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
from llama_stack_api import json_schema_type from llama_stack_api import json_schema_type
@ -14,14 +14,14 @@ from llama_stack_api import json_schema_type
@json_schema_type @json_schema_type
class TogetherImplConfig(RemoteInferenceProviderConfig): class TogetherImplConfig(RemoteInferenceProviderConfig):
url: str = Field( base_url: HttpUrl | None = Field(
default="https://api.together.xyz/v1", default=HttpUrl("https://api.together.xyz/v1"),
description="The URL for the Together AI server", description="The URL for the Together AI server",
) )
@classmethod @classmethod
def sample_run_config(cls, **kwargs) -> dict[str, Any]: def sample_run_config(cls, **kwargs) -> dict[str, Any]:
return { return {
"url": "https://api.together.xyz/v1", "base_url": "https://api.together.xyz/v1",
"api_key": "${env.TOGETHER_API_KEY:=}", "api_key": "${env.TOGETHER_API_KEY:=}",
} }

View file

@ -9,7 +9,6 @@ from collections.abc import Iterable
from typing import Any, cast from typing import Any, cast
from together import AsyncTogether # type: ignore[import-untyped] from together import AsyncTogether # type: ignore[import-untyped]
from together.constants import BASE_URL # type: ignore[import-untyped]
from llama_stack.core.request_headers import NeedsRequestProviderData from llama_stack.core.request_headers import NeedsRequestProviderData
from llama_stack.log import get_logger from llama_stack.log import get_logger
@ -42,7 +41,7 @@ class TogetherInferenceAdapter(OpenAIMixin, NeedsRequestProviderData):
provider_data_api_key_field: str = "together_api_key" provider_data_api_key_field: str = "together_api_key"
def get_base_url(self): def get_base_url(self):
return BASE_URL return str(self.config.base_url)
def _get_client(self) -> AsyncTogether: def _get_client(self) -> AsyncTogether:
together_api_key = None together_api_key = None

View file

@ -6,7 +6,7 @@
from pathlib import Path from pathlib import Path
from pydantic import Field, SecretStr, field_validator from pydantic import Field, HttpUrl, SecretStr, field_validator
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
from llama_stack_api import json_schema_type from llama_stack_api import json_schema_type
@ -14,7 +14,7 @@ from llama_stack_api import json_schema_type
@json_schema_type @json_schema_type
class VLLMInferenceAdapterConfig(RemoteInferenceProviderConfig): class VLLMInferenceAdapterConfig(RemoteInferenceProviderConfig):
url: str | None = Field( base_url: HttpUrl | None = Field(
default=None, default=None,
description="The URL for the vLLM model serving endpoint", description="The URL for the vLLM model serving endpoint",
) )
@ -48,11 +48,11 @@ class VLLMInferenceAdapterConfig(RemoteInferenceProviderConfig):
@classmethod @classmethod
def sample_run_config( def sample_run_config(
cls, cls,
url: str = "${env.VLLM_URL:=}", base_url: str = "${env.VLLM_URL:=}",
**kwargs, **kwargs,
): ):
return { return {
"url": url, "base_url": base_url,
"max_tokens": "${env.VLLM_MAX_TOKENS:=4096}", "max_tokens": "${env.VLLM_MAX_TOKENS:=4096}",
"api_token": "${env.VLLM_API_TOKEN:=fake}", "api_token": "${env.VLLM_API_TOKEN:=fake}",
"tls_verify": "${env.VLLM_TLS_VERIFY:=true}", "tls_verify": "${env.VLLM_TLS_VERIFY:=true}",

View file

@ -39,12 +39,12 @@ class VLLMInferenceAdapter(OpenAIMixin):
def get_base_url(self) -> str: def get_base_url(self) -> str:
"""Get the base URL from config.""" """Get the base URL from config."""
if not self.config.url: if not self.config.base_url:
raise ValueError("No base URL configured") raise ValueError("No base URL configured")
return self.config.url return str(self.config.base_url)
async def initialize(self) -> None: async def initialize(self) -> None:
if not self.config.url: if not self.config.base_url:
raise ValueError( raise ValueError(
"You must provide a URL in run.yaml (or via the VLLM_URL environment variable) to use vLLM." "You must provide a URL in run.yaml (or via the VLLM_URL environment variable) to use vLLM."
) )

View file

@ -7,7 +7,7 @@
import os import os
from typing import Any from typing import Any
from pydantic import BaseModel, Field from pydantic import BaseModel, Field, HttpUrl
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
from llama_stack_api import json_schema_type from llama_stack_api import json_schema_type
@ -23,7 +23,7 @@ class WatsonXProviderDataValidator(BaseModel):
@json_schema_type @json_schema_type
class WatsonXConfig(RemoteInferenceProviderConfig): class WatsonXConfig(RemoteInferenceProviderConfig):
url: str = Field( base_url: HttpUrl | None = Field(
default_factory=lambda: os.getenv("WATSONX_BASE_URL", "https://us-south.ml.cloud.ibm.com"), default_factory=lambda: os.getenv("WATSONX_BASE_URL", "https://us-south.ml.cloud.ibm.com"),
description="A base url for accessing the watsonx.ai", description="A base url for accessing the watsonx.ai",
) )
@ -39,7 +39,7 @@ class WatsonXConfig(RemoteInferenceProviderConfig):
@classmethod @classmethod
def sample_run_config(cls, **kwargs) -> dict[str, Any]: def sample_run_config(cls, **kwargs) -> dict[str, Any]:
return { return {
"url": "${env.WATSONX_BASE_URL:=https://us-south.ml.cloud.ibm.com}", "base_url": "${env.WATSONX_BASE_URL:=https://us-south.ml.cloud.ibm.com}",
"api_key": "${env.WATSONX_API_KEY:=}", "api_key": "${env.WATSONX_API_KEY:=}",
"project_id": "${env.WATSONX_PROJECT_ID:=}", "project_id": "${env.WATSONX_PROJECT_ID:=}",
} }

View file

@ -255,7 +255,7 @@ class WatsonXInferenceAdapter(LiteLLMOpenAIMixin):
) )
def get_base_url(self) -> str: def get_base_url(self) -> str:
return self.config.url return str(self.config.base_url)
# Copied from OpenAIMixin # Copied from OpenAIMixin
async def check_model_availability(self, model: str) -> bool: async def check_model_availability(self, model: str) -> bool:
@ -316,7 +316,7 @@ class WatsonXInferenceAdapter(LiteLLMOpenAIMixin):
""" """
Retrieves foundation model specifications from the watsonx.ai API. Retrieves foundation model specifications from the watsonx.ai API.
""" """
url = f"{self.config.url}/ml/v1/foundation_model_specs?version=2023-10-25" url = f"{str(self.config.base_url)}/ml/v1/foundation_model_specs?version=2023-10-25"
headers = { headers = {
# Note that there is no authorization header. Listing models does not require authentication. # Note that there is no authorization header. Listing models does not require authentication.
"Content-Type": "application/json", "Content-Type": "application/json",

View file

@ -211,3 +211,23 @@ def test_asymmetric_embeddings(llama_stack_client, embedding_model_id):
assert query_response.embeddings is not None assert query_response.embeddings is not None
``` ```
## TypeScript Client Replays
TypeScript SDK tests can run alongside the Python tests when testing against `server:<config>` stacks. Set `TS_CLIENT_PATH` to either a local checkout path or a published npm version of `llama-stack-client-typescript` to enable them:
```bash
# Use published npm package (responses suite)
TS_CLIENT_PATH=^0.3.2 scripts/integration-tests.sh --stack-config server:ci-tests --suite responses --setup gpt
# Use local checkout from ~/.cache (recommended for development)
git clone https://github.com/llamastack/llama-stack-client-typescript.git ~/.cache/llama-stack-client-typescript
TS_CLIENT_PATH=~/.cache/llama-stack-client-typescript scripts/integration-tests.sh --stack-config server:ci-tests --suite responses --setup gpt
# Run base suite with TypeScript tests
TS_CLIENT_PATH=~/.cache/llama-stack-client-typescript scripts/integration-tests.sh --stack-config server:ci-tests --suite base --setup ollama
```
TypeScript tests run immediately after Python tests pass, using the same replay fixtures. The mapping between Python suites/setups and TypeScript test files is defined in `tests/integration/client-typescript/suites.json`.
If `TS_CLIENT_PATH` is unset, TypeScript tests are skipped entirely.
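The lookup itself is a first-match scan over `suites.json`: an entry must match the suite, and an entry with no `setup` matches any setup. A rough Python rendering of what the Node runner (`run-tests.js`, shown further below) does, for readers who want the semantics without the JavaScript:

```python
import json

def resolve_ts_test_files(suite: str, setup: str, suites_path: str) -> list[str]:
    """First matching entry wins; an entry without a setup acts as a wildcard."""
    with open(suites_path) as f:
        entries = json.load(f)
    for entry in entries:
        if entry.get("suite") != suite:
            continue
        entry_setup = entry.get("setup", "")
        if entry_setup and entry_setup != setup:
            continue
        return entry.get("files", [])
    return []

# e.g. resolve_ts_test_files("responses", "gpt", "tests/integration/client-typescript/suites.json")
# -> ["__tests__/responses.test.ts"]
```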

View file

@ -0,0 +1,104 @@
// Copyright (c) Meta Platforms, Inc. and affiliates.
// All rights reserved.
//
// This source code is licensed under the terms described in the LICENSE file in
// the root directory of this source tree.
/**
* Integration tests for Inference API (Chat Completions).
* Ported from: llama-stack/tests/integration/inference/test_openai_completion.py
*
* IMPORTANT: Test cases must match EXACTLY with Python tests to use recorded API responses.
*/
import { createTestClient, requireTextModel } from '../setup';
describe('Inference API - Chat Completions', () => {
// Test cases matching llama-stack/tests/integration/test_cases/inference/chat_completion.json
const chatCompletionTestCases = [
{
id: 'non_streaming_01',
question: 'Which planet do humans live on?',
expected: 'earth',
testId:
'tests/integration/inference/test_openai_completion.py::test_openai_chat_completion_non_streaming[client_with_models-txt=ollama/llama3.2:3b-instruct-fp16-inference:chat_completion:non_streaming_01]',
},
{
id: 'non_streaming_02',
question: 'Which planet has rings around it with a name starting with letter S?',
expected: 'saturn',
testId:
'tests/integration/inference/test_openai_completion.py::test_openai_chat_completion_non_streaming[client_with_models-txt=ollama/llama3.2:3b-instruct-fp16-inference:chat_completion:non_streaming_02]',
},
];
const streamingTestCases = [
{
id: 'streaming_01',
question: "What's the name of the Sun in latin?",
expected: 'sol',
testId:
'tests/integration/inference/test_openai_completion.py::test_openai_chat_completion_streaming[client_with_models-txt=ollama/llama3.2:3b-instruct-fp16-inference:chat_completion:streaming_01]',
},
{
id: 'streaming_02',
question: 'What is the name of the US captial?',
expected: 'washington',
testId:
'tests/integration/inference/test_openai_completion.py::test_openai_chat_completion_streaming[client_with_models-txt=ollama/llama3.2:3b-instruct-fp16-inference:chat_completion:streaming_02]',
},
];
test.each(chatCompletionTestCases)(
'chat completion non-streaming: $id',
async ({ question, expected, testId }) => {
const client = createTestClient(testId);
const textModel = requireTextModel();
const response = await client.chat.completions.create({
model: textModel,
messages: [
{
role: 'user',
content: question,
},
],
stream: false,
});
// Non-streaming responses have choices with message property
const choice = response.choices[0];
expect(choice).toBeDefined();
if (!choice || !('message' in choice)) {
throw new Error('Expected non-streaming response with message');
}
const content = choice.message.content;
expect(content).toBeDefined();
const messageContent = typeof content === 'string' ? content.toLowerCase().trim() : '';
expect(messageContent.length).toBeGreaterThan(0);
expect(messageContent).toContain(expected.toLowerCase());
},
);
test.each(streamingTestCases)('chat completion streaming: $id', async ({ question, expected, testId }) => {
const client = createTestClient(testId);
const textModel = requireTextModel();
const stream = await client.chat.completions.create({
model: textModel,
messages: [{ role: 'user', content: question }],
stream: true,
});
const streamedContent: string[] = [];
for await (const chunk of stream) {
if (chunk.choices && chunk.choices.length > 0 && chunk.choices[0]?.delta?.content) {
streamedContent.push(chunk.choices[0].delta.content);
}
}
expect(streamedContent.length).toBeGreaterThan(0);
const fullContent = streamedContent.join('').toLowerCase().trim();
expect(fullContent).toContain(expected.toLowerCase());
});
});

View file

@ -0,0 +1,132 @@
// Copyright (c) Meta Platforms, Inc. and affiliates.
// All rights reserved.
//
// This source code is licensed under the terms described in the LICENSE file in
// the root directory of this source tree.
/**
* Integration tests for Responses API.
* Ported from: llama-stack/tests/integration/responses/test_basic_responses.py
*
* IMPORTANT: Test cases and IDs must match EXACTLY with Python tests to use recorded API responses.
*/
import { createTestClient, requireTextModel, getResponseOutputText } from '../setup';
describe('Responses API - Basic', () => {
// Test cases matching llama-stack/tests/integration/responses/fixtures/test_cases.py
const basicTestCases = [
{
id: 'earth',
input: 'Which planet do humans live on?',
expected: 'earth',
// Use client_with_models fixture to match non-streaming recordings
testId:
'tests/integration/responses/test_basic_responses.py::test_response_non_streaming_basic[client_with_models-txt=openai/gpt-4o-earth]',
},
{
id: 'saturn',
input: 'Which planet has rings around it with a name starting with letter S?',
expected: 'saturn',
testId:
'tests/integration/responses/test_basic_responses.py::test_response_non_streaming_basic[client_with_models-txt=openai/gpt-4o-saturn]',
},
];
test.each(basicTestCases)('non-streaming basic response: $id', async ({ input, expected, testId }) => {
// Create client with test_id for all requests
const client = createTestClient(testId);
const textModel = requireTextModel();
// Create a response
const response = await client.responses.create({
model: textModel,
input,
stream: false,
});
// Verify response has content
const outputText = getResponseOutputText(response).toLowerCase().trim();
expect(outputText.length).toBeGreaterThan(0);
expect(outputText).toContain(expected.toLowerCase());
// Verify usage is reported
expect(response.usage).toBeDefined();
expect(response.usage!.input_tokens).toBeGreaterThan(0);
expect(response.usage!.output_tokens).toBeGreaterThan(0);
expect(response.usage!.total_tokens).toBe(response.usage!.input_tokens + response.usage!.output_tokens);
// Verify stored response matches
const retrievedResponse = await client.responses.retrieve(response.id);
expect(getResponseOutputText(retrievedResponse)).toBe(getResponseOutputText(response));
// Test follow-up with previous_response_id
const nextResponse = await client.responses.create({
model: textModel,
input: 'Repeat your previous response in all caps.',
previous_response_id: response.id,
});
const nextOutputText = getResponseOutputText(nextResponse).trim();
expect(nextOutputText).toContain(expected.toUpperCase());
});
test.each(basicTestCases)('streaming basic response: $id', async ({ input, expected, testId }) => {
// Modify test_id for streaming variant
const streamingTestId = testId.replace(
'test_response_non_streaming_basic',
'test_response_streaming_basic',
);
const client = createTestClient(streamingTestId);
const textModel = requireTextModel();
// Create a streaming response
const stream = await client.responses.create({
model: textModel,
input,
stream: true,
});
const events: any[] = [];
let responseId = '';
for await (const chunk of stream) {
events.push(chunk);
if (chunk.type === 'response.created') {
// Verify response.created is the first event
expect(events.length).toBe(1);
expect(chunk.response.status).toBe('in_progress');
responseId = chunk.response.id;
} else if (chunk.type === 'response.completed') {
// Verify response.completed comes after response.created
expect(events.length).toBeGreaterThanOrEqual(2);
expect(chunk.response.status).toBe('completed');
expect(chunk.response.id).toBe(responseId);
// Verify content quality
const outputText = getResponseOutputText(chunk.response).toLowerCase().trim();
expect(outputText.length).toBeGreaterThan(0);
expect(outputText).toContain(expected.toLowerCase());
// Verify usage is reported
expect(chunk.response.usage).toBeDefined();
expect(chunk.response.usage!.input_tokens).toBeGreaterThan(0);
expect(chunk.response.usage!.output_tokens).toBeGreaterThan(0);
expect(chunk.response.usage!.total_tokens).toBe(
chunk.response.usage!.input_tokens + chunk.response.usage!.output_tokens,
);
}
}
// Verify we got both events
expect(events.length).toBeGreaterThanOrEqual(2);
const firstEvent = events[0];
const lastEvent = events[events.length - 1];
expect(firstEvent.type).toBe('response.created');
expect(lastEvent.type).toBe('response.completed');
// Verify stored response matches streamed response
const retrievedResponse = await client.responses.retrieve(responseId);
expect(getResponseOutputText(retrievedResponse)).toBe(getResponseOutputText(lastEvent.response));
});
});

View file

@ -0,0 +1,31 @@
// Copyright (c) Meta Platforms, Inc. and affiliates.
// All rights reserved.
//
// This source code is licensed under the terms described in the LICENSE file in
// the root directory of this source tree.
/** @type {import('ts-jest').JestConfigWithTsJest} */
module.exports = {
preset: 'ts-jest/presets/default-esm',
testEnvironment: 'node',
extensionsToTreatAsEsm: ['.ts'],
moduleNameMapper: {
'^(\\.{1,2}/.*)\\.js$': '$1',
},
transform: {
'^.+\\.tsx?$': [
'ts-jest',
{
useESM: true,
tsconfig: {
module: 'ES2022',
moduleResolution: 'bundler',
},
},
],
},
testMatch: ['<rootDir>/__tests__/**/*.test.ts'],
setupFilesAfterEnv: ['<rootDir>/setup.ts'],
testTimeout: 60000, // 60 seconds (integration tests can be slow)
watchman: false, // Disable watchman to avoid permission issues
};

File diff suppressed because it is too large

View file

@ -0,0 +1,18 @@
{
"name": "llama-stack-typescript-integration-tests",
"version": "0.0.1",
"private": true,
"description": "TypeScript client integration tests for Llama Stack",
"scripts": {
"test": "node run-tests.js"
},
"devDependencies": {
"@swc/core": "^1.3.102",
"@swc/jest": "^0.2.29",
"@types/jest": "^29.4.0",
"@types/node": "^20.0.0",
"jest": "^29.4.0",
"ts-jest": "^29.1.0",
"typescript": "^5.0.0"
}
}

View file

@ -0,0 +1,63 @@
#!/usr/bin/env node
// Copyright (c) Meta Platforms, Inc. and affiliates.
// All rights reserved.
//
// This source code is licensed under the terms described in the LICENSE file in
// the root directory of this source tree.
/**
* Test runner that finds and executes TypeScript tests based on suite/setup mapping.
* Called by integration-tests.sh via npm test.
*/
const fs = require('fs');
const path = require('path');
const { execSync } = require('child_process');
const suite = process.env.LLAMA_STACK_TEST_SUITE;
const setup = process.env.LLAMA_STACK_TEST_SETUP || '';
if (!suite) {
console.error('Error: LLAMA_STACK_TEST_SUITE environment variable is required');
process.exit(1);
}
// Read suites.json to find matching test files
const suitesPath = path.join(__dirname, 'suites.json');
if (!fs.existsSync(suitesPath)) {
console.log(`No TypeScript tests configured (${suitesPath} not found)`);
process.exit(0);
}
const suites = JSON.parse(fs.readFileSync(suitesPath, 'utf-8'));
// Find matching entry
let testFiles = [];
for (const entry of suites) {
if (entry.suite !== suite) {
continue;
}
const entrySetup = entry.setup || '';
if (entrySetup && entrySetup !== setup) {
continue;
}
testFiles = entry.files || [];
break;
}
if (testFiles.length === 0) {
console.log(`No TypeScript integration tests mapped for suite ${suite} (setup ${setup})`);
process.exit(0);
}
console.log(`Running TypeScript tests for suite ${suite} (setup ${setup}): ${testFiles.join(', ')}`);
// Run Jest with the mapped test files
try {
execSync(`npx jest --config jest.integration.config.js ${testFiles.join(' ')}`, {
stdio: 'inherit',
cwd: __dirname,
});
} catch (error) {
process.exit(error.status || 1);
}

View file

@ -0,0 +1,162 @@
// Copyright (c) Meta Platforms, Inc. and affiliates.
// All rights reserved.
//
// This source code is licensed under the terms described in the LICENSE file in
// the root directory of this source tree.
/**
* Global setup for integration tests.
* This file mimics pytest's fixture system by providing shared test configuration.
*/
import LlamaStackClient from 'llama-stack-client';
/**
* Load test configuration from the Python setup system.
* This reads setup definitions from tests/integration/suites.py via get_setup_env.py.
*/
function loadTestConfig() {
const baseURL = process.env['TEST_API_BASE_URL'];
const setupName = process.env['LLAMA_STACK_TEST_SETUP'];
const textModel = process.env['LLAMA_STACK_TEST_TEXT_MODEL'];
const embeddingModel = process.env['LLAMA_STACK_TEST_EMBEDDING_MODEL'];
if (!baseURL) {
throw new Error(
'TEST_API_BASE_URL is required for integration tests. ' +
'Run tests using: ./scripts/integration-tests.sh',
);
}
return {
baseURL,
textModel,
embeddingModel,
setupName,
};
}
// Read configuration from environment variables (set by scripts/integration-tests.sh)
export const TEST_CONFIG = loadTestConfig();
// Validate required configuration
beforeAll(() => {
console.log('\n=== Integration Test Configuration ===');
console.log(`Base URL: ${TEST_CONFIG.baseURL}`);
console.log(`Setup: ${TEST_CONFIG.setupName || 'NOT SET'}`);
console.log(
`Text Model: ${TEST_CONFIG.textModel || 'NOT SET - tests requiring text model will be skipped'}`,
);
console.log(
`Embedding Model: ${
TEST_CONFIG.embeddingModel || 'NOT SET - tests requiring embedding model will be skipped'
}`,
);
console.log('=====================================\n');
});
/**
* Create a client instance for integration tests.
* Mimics pytest's `llama_stack_client` fixture.
*
* @param testId - Test ID to send in X-LlamaStack-Provider-Data header for replay mode.
* Format: "tests/integration/responses/test_basic_responses.py::test_name[params]"
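 *
 * Example (illustrative only; the test ID and prompt are made up, and the
 * `responses.create` call assumes the client's OpenAI-style Responses API):
 *
 *   const client = createTestClient('tests/integration/responses/test_basic_responses.py::test_basic[gpt-4o]');
 *   const response = await client.responses.create({ model: requireTextModel(), input: 'Say hello' });
 *   console.log(getResponseOutputText(response));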
*/
export function createTestClient(testId?: string): LlamaStackClient {
const headers: Record<string, string> = {};
// In server mode with replay, send test ID for recording isolation
if (process.env['LLAMA_STACK_TEST_STACK_CONFIG_TYPE'] === 'server' && testId) {
headers['X-LlamaStack-Provider-Data'] = JSON.stringify({
__test_id: testId,
});
}
return new LlamaStackClient({
baseURL: TEST_CONFIG.baseURL,
timeout: 60000, // 60 seconds
defaultHeaders: headers,
});
}
/**
* Skip test if required model is not configured.
* Mimics pytest's `skip_if_no_model` autouse fixture.
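 *
 * Example (illustrative test name and body):
 *   skipIfNoModel('text')('generates a completion', async () => { ... });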
*/
export function skipIfNoModel(modelType: 'text' | 'embedding'): typeof test {
  const model = modelType === 'text' ? TEST_CONFIG.textModel : TEST_CONFIG.embeddingModel;
  if (!model) {
    const envVar = modelType === 'text' ? 'LLAMA_STACK_TEST_TEXT_MODEL' : 'LLAMA_STACK_TEST_EMBEDDING_MODEL';
    console.warn(`Skipping: ${modelType} model not configured (set ${envVar})`);
    return test.skip.bind(test) as typeof test;
  }
  return test;
}
/**
* Get the configured text model, throwing if not set.
* Use this in tests that absolutely require a text model.
*/
export function requireTextModel(): string {
if (!TEST_CONFIG.textModel) {
throw new Error(
'LLAMA_STACK_TEST_TEXT_MODEL environment variable is required. ' +
'Run tests using: ./scripts/integration-test.sh',
);
}
return TEST_CONFIG.textModel;
}
/**
* Get the configured embedding model, throwing if not set.
* Use this in tests that absolutely require an embedding model.
*/
export function requireEmbeddingModel(): string {
if (!TEST_CONFIG.embeddingModel) {
throw new Error(
'LLAMA_STACK_TEST_EMBEDDING_MODEL environment variable is required. ' +
'Run tests using: ./scripts/integration-test.sh',
);
}
return TEST_CONFIG.embeddingModel;
}
/**
* Extracts aggregated text output from a ResponseObject.
* This concatenates all text content from the response's output array.
*
 * Copied from llama-stack-client's response-helpers until it's available in a published version.
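 *
 * Example: an output of
 *   [{ type: 'message', content: [{ type: 'output_text', text: 'Hello' }, { type: 'output_text', text: ' world' }] }]
 * yields 'Hello world'.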
*/
export function getResponseOutputText(response: any): string {
const pieces: string[] = [];
for (const output of response.output ?? []) {
if (!output || output.type !== 'message') {
continue;
}
const content = output.content;
if (typeof content === 'string') {
pieces.push(content);
continue;
}
if (!Array.isArray(content)) {
continue;
}
for (const item of content) {
if (typeof item === 'string') {
pieces.push(item);
continue;
}
if (item && item.type === 'output_text' && 'text' in item && typeof item.text === 'string') {
pieces.push(item.text);
}
}
}
return pieces.join('');
}

View file

@ -0,0 +1,12 @@
[
{
"suite": "responses",
"setup": "gpt",
"files": ["__tests__/responses.test.ts"]
},
{
"suite": "base",
"setup": "ollama",
"files": ["__tests__/inference.test.ts"]
}
]

View file

@ -0,0 +1,16 @@
{
"compilerOptions": {
"target": "ES2022",
"module": "ES2022",
"lib": ["ES2022"],
"moduleResolution": "bundler",
"esModuleInterop": true,
"allowSyntheticDefaultImports": true,
"strict": true,
"skipLibCheck": true,
"resolveJsonModule": true,
"types": ["jest", "node"]
},
"include": ["**/*.ts"],
"exclude": ["node_modules"]
}

View file

@ -50,7 +50,7 @@ SETUP_DEFINITIONS: dict[str, Setup] = {
name="ollama", name="ollama",
description="Local Ollama provider with text + safety models", description="Local Ollama provider with text + safety models",
env={ env={
"OLLAMA_URL": "http://0.0.0.0:11434", "OLLAMA_URL": "http://0.0.0.0:11434/v1",
"SAFETY_MODEL": "ollama/llama-guard3:1b", "SAFETY_MODEL": "ollama/llama-guard3:1b",
}, },
defaults={ defaults={
@ -64,7 +64,7 @@ SETUP_DEFINITIONS: dict[str, Setup] = {
name="ollama", name="ollama",
description="Local Ollama provider with a vision model", description="Local Ollama provider with a vision model",
env={ env={
"OLLAMA_URL": "http://0.0.0.0:11434", "OLLAMA_URL": "http://0.0.0.0:11434/v1",
}, },
defaults={ defaults={
"vision_model": "ollama/llama3.2-vision:11b", "vision_model": "ollama/llama3.2-vision:11b",
@ -75,7 +75,7 @@ SETUP_DEFINITIONS: dict[str, Setup] = {
name="ollama-postgres", name="ollama-postgres",
description="Server-mode tests with Postgres-backed persistence", description="Server-mode tests with Postgres-backed persistence",
env={ env={
"OLLAMA_URL": "http://0.0.0.0:11434", "OLLAMA_URL": "http://0.0.0.0:11434/v1",
"SAFETY_MODEL": "ollama/llama-guard3:1b", "SAFETY_MODEL": "ollama/llama-guard3:1b",
"POSTGRES_HOST": "127.0.0.1", "POSTGRES_HOST": "127.0.0.1",
"POSTGRES_PORT": "5432", "POSTGRES_PORT": "5432",

View file

@ -0,0 +1,206 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
"""Tests for making Safety API optional in meta-reference agents provider.
This test suite validates the changes introduced to fix issue #4165, which
allows running the meta-reference agents provider without the Safety API.
Safety API is now an optional dependency, and errors are raised at request time
when guardrails are explicitly requested without Safety API configured.
"""
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from llama_stack.core.datatypes import Api
from llama_stack.core.storage.datatypes import KVStoreReference, ResponsesStoreReference
from llama_stack.providers.inline.agents.meta_reference import get_provider_impl
from llama_stack.providers.inline.agents.meta_reference.config import (
AgentPersistenceConfig,
MetaReferenceAgentsImplConfig,
)
from llama_stack.providers.inline.agents.meta_reference.responses.utils import (
run_guardrails,
)
@pytest.fixture
def mock_persistence_config():
"""Create a mock persistence configuration."""
return AgentPersistenceConfig(
agent_state=KVStoreReference(
backend="kv_default",
namespace="agents",
),
responses=ResponsesStoreReference(
backend="sql_default",
table_name="responses",
),
)
@pytest.fixture
def mock_deps():
"""Create mock dependencies for the agents provider."""
# Create mock APIs
inference_api = AsyncMock()
vector_io_api = AsyncMock()
tool_runtime_api = AsyncMock()
tool_groups_api = AsyncMock()
conversations_api = AsyncMock()
return {
Api.inference: inference_api,
Api.vector_io: vector_io_api,
Api.tool_runtime: tool_runtime_api,
Api.tool_groups: tool_groups_api,
Api.conversations: conversations_api,
}
class TestProviderInitialization:
"""Test provider initialization with different safety API configurations."""
async def test_initialization_with_safety_api_present(self, mock_persistence_config, mock_deps):
"""Test successful initialization when Safety API is configured."""
config = MetaReferenceAgentsImplConfig(persistence=mock_persistence_config)
# Add safety API to deps
safety_api = AsyncMock()
mock_deps[Api.safety] = safety_api
# Mock the initialize method to avoid actual initialization
with patch(
"llama_stack.providers.inline.agents.meta_reference.agents.MetaReferenceAgentsImpl.initialize",
new_callable=AsyncMock,
):
# Should not raise any exception
provider = await get_provider_impl(config, mock_deps, policy=[], telemetry_enabled=False)
assert provider is not None
async def test_initialization_without_safety_api(self, mock_persistence_config, mock_deps):
"""Test successful initialization when Safety API is not configured."""
config = MetaReferenceAgentsImplConfig(persistence=mock_persistence_config)
# Safety API is NOT in mock_deps - provider should still start
# Mock the initialize method to avoid actual initialization
with patch(
"llama_stack.providers.inline.agents.meta_reference.agents.MetaReferenceAgentsImpl.initialize",
new_callable=AsyncMock,
):
# Should not raise any exception
provider = await get_provider_impl(config, mock_deps, policy=[], telemetry_enabled=False)
assert provider is not None
assert provider.safety_api is None
class TestGuardrailsFunctionality:
"""Test run_guardrails function with optional safety API."""
async def test_run_guardrails_with_none_safety_api(self):
"""Test that run_guardrails returns None when safety_api is None."""
result = await run_guardrails(safety_api=None, messages="test message", guardrail_ids=["llama-guard"])
assert result is None
async def test_run_guardrails_with_empty_messages(self):
"""Test that run_guardrails returns None for empty messages."""
# Test with None safety API
result = await run_guardrails(safety_api=None, messages="", guardrail_ids=["llama-guard"])
assert result is None
# Test with mock safety API
mock_safety_api = AsyncMock()
result = await run_guardrails(safety_api=mock_safety_api, messages="", guardrail_ids=["llama-guard"])
assert result is None
async def test_run_guardrails_with_none_safety_api_ignores_guardrails(self):
"""Test that guardrails are skipped when safety_api is None, even if guardrail_ids are provided."""
# Should not raise exception, just return None
result = await run_guardrails(
safety_api=None,
messages="potentially harmful content",
guardrail_ids=["llama-guard", "content-filter"],
)
assert result is None
async def test_create_response_rejects_guardrails_without_safety_api(self, mock_persistence_config, mock_deps):
"""Test that create_openai_response raises error when guardrails requested but Safety API unavailable."""
from llama_stack.providers.inline.agents.meta_reference.responses.openai_responses import (
OpenAIResponsesImpl,
)
from llama_stack_api import ResponseGuardrailSpec
# Create OpenAIResponsesImpl with no safety API
with patch("llama_stack.providers.inline.agents.meta_reference.responses.openai_responses.ResponsesStore"):
impl = OpenAIResponsesImpl(
inference_api=mock_deps[Api.inference],
tool_groups_api=mock_deps[Api.tool_groups],
tool_runtime_api=mock_deps[Api.tool_runtime],
responses_store=MagicMock(),
vector_io_api=mock_deps[Api.vector_io],
safety_api=None, # No Safety API
conversations_api=mock_deps[Api.conversations],
)
# Test with string guardrail
with pytest.raises(ValueError) as exc_info:
await impl.create_openai_response(
input="test input",
model="test-model",
guardrails=["llama-guard"],
)
assert "Cannot process guardrails: Safety API is not configured" in str(exc_info.value)
# Test with ResponseGuardrailSpec
with pytest.raises(ValueError) as exc_info:
await impl.create_openai_response(
input="test input",
model="test-model",
guardrails=[ResponseGuardrailSpec(type="llama-guard")],
)
assert "Cannot process guardrails: Safety API is not configured" in str(exc_info.value)
async def test_create_response_succeeds_without_guardrails_and_no_safety_api(
self, mock_persistence_config, mock_deps
):
"""Test that create_openai_response works when no guardrails requested and Safety API unavailable."""
from llama_stack.providers.inline.agents.meta_reference.responses.openai_responses import (
OpenAIResponsesImpl,
)
# Create OpenAIResponsesImpl with no safety API
with (
patch("llama_stack.providers.inline.agents.meta_reference.responses.openai_responses.ResponsesStore"),
patch.object(OpenAIResponsesImpl, "_create_streaming_response", new_callable=AsyncMock) as mock_stream,
):
# Mock the streaming response to return a simple async generator
async def mock_generator():
yield MagicMock()
mock_stream.return_value = mock_generator()
impl = OpenAIResponsesImpl(
inference_api=mock_deps[Api.inference],
tool_groups_api=mock_deps[Api.tool_groups],
tool_runtime_api=mock_deps[Api.tool_runtime],
responses_store=MagicMock(),
vector_io_api=mock_deps[Api.vector_io],
safety_api=None, # No Safety API
conversations_api=mock_deps[Api.conversations],
)
# Should not raise when no guardrails requested
# Note: This will still fail later in execution due to mocking, but should pass the validation
try:
await impl.create_openai_response(
input="test input",
model="test-model",
guardrails=None, # No guardrails
)
except Exception as e:
# Ensure the error is NOT about missing Safety API
assert "Cannot process guardrails: Safety API is not configured" not in str(e)

View file

@ -120,7 +120,7 @@ from llama_stack.providers.remote.inference.watsonx.watsonx import WatsonXInfere
VLLMInferenceAdapter, VLLMInferenceAdapter,
"llama_stack.providers.remote.inference.vllm.VLLMProviderDataValidator", "llama_stack.providers.remote.inference.vllm.VLLMProviderDataValidator",
{ {
"url": "http://fake", "base_url": "http://fake",
}, },
), ),
], ],
@ -153,7 +153,7 @@ def test_litellm_provider_data_used(config_cls, adapter_cls, provider_data_valid
"""Validate data for LiteLLM-based providers. Similar to test_openai_provider_data_used, but without the """Validate data for LiteLLM-based providers. Similar to test_openai_provider_data_used, but without the
assumption that there is an OpenAI-compatible client object.""" assumption that there is an OpenAI-compatible client object."""
inference_adapter = adapter_cls(config=config_cls()) inference_adapter = adapter_cls(config=config_cls(base_url="http://fake"))
inference_adapter.__provider_spec__ = MagicMock() inference_adapter.__provider_spec__ = MagicMock()
inference_adapter.__provider_spec__.provider_data_validator = provider_data_validator inference_adapter.__provider_spec__.provider_data_validator = provider_data_validator

View file

@ -40,7 +40,7 @@ from llama_stack_api import (
@pytest.fixture(scope="function") @pytest.fixture(scope="function")
async def vllm_inference_adapter(): async def vllm_inference_adapter():
config = VLLMInferenceAdapterConfig(url="http://mocked.localhost:12345") config = VLLMInferenceAdapterConfig(base_url="http://mocked.localhost:12345")
inference_adapter = VLLMInferenceAdapter(config=config) inference_adapter = VLLMInferenceAdapter(config=config)
inference_adapter.model_store = AsyncMock() inference_adapter.model_store = AsyncMock()
await inference_adapter.initialize() await inference_adapter.initialize()
@ -204,7 +204,7 @@ async def test_vllm_completion_extra_body():
via extra_body to the underlying OpenAI client through the InferenceRouter. via extra_body to the underlying OpenAI client through the InferenceRouter.
""" """
# Set up the vLLM adapter # Set up the vLLM adapter
config = VLLMInferenceAdapterConfig(url="http://mocked.localhost:12345") config = VLLMInferenceAdapterConfig(base_url="http://mocked.localhost:12345")
vllm_adapter = VLLMInferenceAdapter(config=config) vllm_adapter = VLLMInferenceAdapter(config=config)
vllm_adapter.__provider_id__ = "vllm" vllm_adapter.__provider_id__ = "vllm"
await vllm_adapter.initialize() await vllm_adapter.initialize()
@ -277,7 +277,7 @@ async def test_vllm_chat_completion_extra_body():
via extra_body to the underlying OpenAI client through the InferenceRouter for chat completion. via extra_body to the underlying OpenAI client through the InferenceRouter for chat completion.
""" """
# Set up the vLLM adapter # Set up the vLLM adapter
config = VLLMInferenceAdapterConfig(url="http://mocked.localhost:12345") config = VLLMInferenceAdapterConfig(base_url="http://mocked.localhost:12345")
vllm_adapter = VLLMInferenceAdapter(config=config) vllm_adapter = VLLMInferenceAdapter(config=config)
vllm_adapter.__provider_id__ = "vllm" vllm_adapter.__provider_id__ = "vllm"
await vllm_adapter.initialize() await vllm_adapter.initialize()

View file

@ -146,7 +146,7 @@ async def test_hosted_model_not_in_endpoint_mapping():
async def test_self_hosted_ignores_endpoint(): async def test_self_hosted_ignores_endpoint():
adapter = create_adapter( adapter = create_adapter(
config=NVIDIAConfig(url="http://localhost:8000", api_key=None), config=NVIDIAConfig(base_url="http://localhost:8000", api_key=None),
rerank_endpoints={"test-model": "https://model.endpoint/rerank"}, # This should be ignored for self-hosted. rerank_endpoints={"test-model": "https://model.endpoint/rerank"}, # This should be ignored for self-hosted.
) )
mock_session = MockSession(MockResponse()) mock_session = MockSession(MockResponse())

View file

@ -4,8 +4,10 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
from typing import get_args, get_origin
import pytest import pytest
from pydantic import BaseModel from pydantic import BaseModel, HttpUrl
from llama_stack.core.distribution import get_provider_registry, providable_apis from llama_stack.core.distribution import get_provider_registry, providable_apis
from llama_stack.core.utils.dynamic import instantiate_class_type from llama_stack.core.utils.dynamic import instantiate_class_type
@ -41,3 +43,55 @@ class TestProviderConfigurations:
sample_config = config_type.sample_run_config(__distro_dir__="foobarbaz") sample_config = config_type.sample_run_config(__distro_dir__="foobarbaz")
assert isinstance(sample_config, dict), f"{config_class_name}.sample_run_config() did not return a dict" assert isinstance(sample_config, dict), f"{config_class_name}.sample_run_config() did not return a dict"
def test_remote_inference_url_standardization(self):
"""Verify all remote inference providers use standardized base_url configuration."""
provider_registry = get_provider_registry()
inference_providers = provider_registry.get("inference", {})
# Filter for remote providers only
remote_providers = {k: v for k, v in inference_providers.items() if k.startswith("remote::")}
failures = []
for provider_type, provider_spec in remote_providers.items():
try:
config_class_name = provider_spec.config_class
config_type = instantiate_class_type(config_class_name)
# Check that config has base_url field (not url)
if hasattr(config_type, "model_fields"):
fields = config_type.model_fields
# Should NOT have 'url' field (old pattern)
if "url" in fields:
failures.append(
f"{provider_type}: Uses deprecated 'url' field instead of 'base_url'. "
f"Please rename to 'base_url' for consistency."
)
# Should have 'base_url' field with HttpUrl | None type
if "base_url" in fields:
field_info = fields["base_url"]
annotation = field_info.annotation
# Check if it's HttpUrl or HttpUrl | None
# get_origin() returns Union for (X | Y), None for plain types
# get_args() returns the types inside Union, e.g. (HttpUrl, NoneType)
is_valid = False
if get_origin(annotation) is not None: # It's a Union/Optional
if HttpUrl in get_args(annotation):
is_valid = True
elif annotation == HttpUrl: # Plain HttpUrl without | None
is_valid = True
if not is_valid:
failures.append(
f"{provider_type}: base_url field has incorrect type annotation. "
f"Expected 'HttpUrl | None', got '{annotation}'"
)
except Exception as e:
failures.append(f"{provider_type}: Error checking URL standardization: {str(e)}")
if failures:
pytest.fail("URL standardization violations found:\n" + "\n".join(f" - {f}" for f in failures))